diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -743,6 +743,7 @@
   setTargetDAGCombine(ISD::INTRINSIC_VOID);
   setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN);
   setTargetDAGCombine(ISD::INSERT_VECTOR_ELT);
+  setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
 
   setTargetDAGCombine(ISD::GlobalAddress);
 
@@ -11580,6 +11581,61 @@
   return ResultHADD;
 }
 
+/// Target DAG combine for ISD::EXTRACT_VECTOR_ELT nodes.
+///
+/// Scalarises a lane-0 extract of a vector fadd whose operands are a vector
+/// and a shuffle of that same vector with lane 1 moved into lane 0, so that
+/// only lanes 0 and 1 are added. The faddp.ll expectations updated with this
+/// change show the result being emitted as a single "faddp s0, v0.2s".
+///
+/// \param N   The EXTRACT_VECTOR_ELT node to combine.
+/// \param DAG The DAG in which the replacement nodes are created.
+/// \returns The replacement scalar fadd, or an empty SDValue when the pattern
+///          does not match.
+static SDValue performExtractVectorEltCombine(SDNode *N, SelectionDAG &DAG) {
+  SDValue N0 = N->getOperand(0), N1 = N->getOperand(1);
+  ConstantSDNode *ConstantN1 = dyn_cast<ConstantSDNode>(N1);
+
+  EVT VT = N->getValueType(0);
+
+  // Rewrite for pairwise fadd pattern
+  //   (f32 (extract_vector_elt
+  //           (fadd (vXf32 Other)
+  //                 (vector_shuffle (vXf32 Other) undef <1,X,...> )) 0))
+  // ->
+  //   (f32 (fadd (extract_vector_elt (vXf32 Other) 0)
+  //              (extract_vector_elt (vXf32 Other) 1)))
+  if (ConstantN1 && ConstantN1->getZExtValue() == 0 &&
+      N0->getOpcode() == ISD::FADD && VT == MVT::f32) {
+    SDLoc DL(N0);
+    SDValue N00 = N0->getOperand(0);
+    SDValue N01 = N0->getOperand(1);
+
+    ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(N01);
+    SDValue Other = N00;
+
+    // And handle the commutative case.
+    if (!Shuffle) {
+      Shuffle = dyn_cast<ShuffleVectorSDNode>(N00);
+      Other = N01;
+    }
+
+    if (Shuffle && Shuffle->getMaskElt(0) == 1 &&
+        Other == Shuffle->getOperand(0)) {
+      // Build every new node at the vector fadd's location and carry its
+      // node flags over so the rewrite does not drop them.
+      SDValue Lane0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
+                                  DAG.getConstant(0, DL, MVT::i64));
+      SDValue Lane1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
+                                  DAG.getConstant(1, DL, MVT::i64));
+      return DAG.getNode(N0->getOpcode(), DL, VT, Lane0, Lane1,
+                         N0->getFlags());
+    }
+  }
+
+  return SDValue();
+}
+
 static SDValue performConcatVectorsCombine(SDNode *N,
                                            TargetLowering::DAGCombinerInfo &DCI,
                                            SelectionDAG &DAG) {
@@ -14403,6 +14459,8 @@
     return performUzpCombine(N, DAG);
   case ISD::INSERT_VECTOR_ELT:
     return performPostLD1Combine(N, DCI, true);
+  case ISD::EXTRACT_VECTOR_ELT:
+    return performExtractVectorEltCombine(N, DAG);
   case ISD::INTRINSIC_VOID:
   case ISD::INTRINSIC_W_CHAIN:
     switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {
diff --git a/llvm/test/CodeGen/AArch64/faddp.ll b/llvm/test/CodeGen/AArch64/faddp.ll
--- a/llvm/test/CodeGen/AArch64/faddp.ll
+++ b/llvm/test/CodeGen/AArch64/faddp.ll
@@ -5,9 +5,7 @@
 ; CHECK-LABEL: faddp_2x:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NEXT:    dup v1.2s, v0.s[1]
-; CHECK-NEXT:    fadd v0.2s, v0.2s, v1.2s
-; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $q0
+; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
 entry:
   %shift = shufflevector <2 x float> %a, <2 x float> undef, <2 x i32> <i32 1, i32 undef>
@@ -19,9 +17,7 @@
 define float @faddp_4x(<4 x float> %a) {
 ; CHECK-LABEL: faddp_4x:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    dup v1.4s, v0.s[1]
-; CHECK-NEXT:    fadd v0.4s, v0.4s, v1.4s
-; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $q0
+; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
 entry:
   %shift = shufflevector <4 x float> %a, <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
@@ -33,9 +29,7 @@
 define float @faddp_4x_commute(<4 x float> %a) {
 ; CHECK-LABEL: faddp_4x_commute:
 ; CHECK:       // %bb.0: // %entry
-; CHECK-NEXT:    dup v1.4s, v0.s[1]
-; CHECK-NEXT:    fadd v0.4s, v1.4s, v0.4s
-; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $q0
+; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
 entry:
   %shift = shufflevector <4 x float> %a, <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
@@ -48,9 +42,7 @@
 ; CHECK-LABEL: faddp_2x_commute:
 ; CHECK:       // %bb.0: // %entry
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
-; CHECK-NEXT:    dup v1.2s, v0.s[1]
-; CHECK-NEXT:    fadd v0.2s, v1.2s, v0.2s
-; CHECK-NEXT:    // kill: def $s0 killed $s0 killed $q0
+; CHECK-NEXT:    faddp s0, v0.2s
 ; CHECK-NEXT:    ret
 entry:
   %shift = shufflevector <2 x float> %a, <2 x float> undef, <2 x i32> <i32 1, i32 undef>