diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14228,6 +14228,7 @@
 static bool hasPairwiseAdd(unsigned Opcode, EVT VT, bool FullFP16) {
   switch (Opcode) {
+  case ISD::STRICT_FADD:
   case ISD::FADD:
     return (FullFP16 && VT == MVT::f16) || VT == MVT::f32 || VT == MVT::f64;
   case ISD::ADD:
@@ -14244,6 +14245,7 @@
   EVT VT = N->getValueType(0);
   const bool FullFP16 =
       static_cast<const AArch64Subtarget &>(DAG.getSubtarget()).hasFullFP16();
+  bool IsStrict = N0->isStrictFPOpcode();

   // Rewrite for pairwise fadd pattern
   //   (f32 (extract_vector_elt
@@ -14252,11 +14254,14 @@
   // ->
   //   (f32 (fadd (extract_vector_elt (vXf32 Other) 0)
   //              (extract_vector_elt (vXf32 Other) 1))
+  // For strict_fadd we need to make sure the old strict_fadd can be deleted, so
+  // we can only do this when it's used only by the extract_vector_elt.
   if (ConstantN1 && ConstantN1->getZExtValue() == 0 &&
-      hasPairwiseAdd(N0->getOpcode(), VT, FullFP16)) {
+      hasPairwiseAdd(N0->getOpcode(), VT, FullFP16) &&
+      (!IsStrict || N0.hasOneUse())) {
     SDLoc DL(N0);
-    SDValue N00 = N0->getOperand(0);
-    SDValue N01 = N0->getOperand(1);
+    SDValue N00 = N0->getOperand(IsStrict ? 1 : 0);
+    SDValue N01 = N0->getOperand(IsStrict ? 2 : 1);

     ShuffleVectorSDNode *Shuffle = dyn_cast<ShuffleVectorSDNode>(N01);
     SDValue Other = N00;
@@ -14269,11 +14274,23 @@
     if (Shuffle && Shuffle->getMaskElt(0) == 1 &&
         Other == Shuffle->getOperand(0)) {
-      return DAG.getNode(N0->getOpcode(), DL, VT,
-                         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
-                                     DAG.getConstant(0, DL, MVT::i64)),
-                         DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
-                                     DAG.getConstant(1, DL, MVT::i64)));
+      SDValue Extract1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
+                                     DAG.getConstant(0, DL, MVT::i64));
+      SDValue Extract2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Other,
+                                     DAG.getConstant(1, DL, MVT::i64));
+      if (!IsStrict)
+        return DAG.getNode(N0->getOpcode(), DL, VT, Extract1, Extract2);
+
+      // For strict_fadd we need uses of the final extract_vector to be replaced
+      // with the strict_fadd, but we also need uses of the chain output of the
+      // original strict_fadd to use the chain output of the new strict_fadd as
+      // otherwise it may not be deleted.
+      SDValue Ret = DAG.getNode(N0->getOpcode(), DL, {VT, MVT::Other},
+                                {N0->getOperand(0), Extract1, Extract2});
+      DAG.ReplaceAllUsesOfValueWith(SDValue(N, 0), Ret);
+      DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Ret.getValue(1));
+      return SDValue(N, 0);
     }
   }
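Background on the operand indexing above (context, not part of the patch): a strict FP node takes its chain as operand 0 and produces an extra chain result, so for STRICT_FADD the two addends sit at operands 1 and 2 rather than 0 and 1. Any rewrite therefore has to rewire both results, which is why the combine indexes the operands with IsStrict ? 1 : 0 / IsStrict ? 2 : 1 and calls ReplaceAllUsesOfValueWith on N0.getValue(1) as well as on the extract itself.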
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -8100,17 +8100,17 @@
 def : Pat<(i64 (add (vector_extract (v2i64 FPR128:$Rn), (i64 0)),
                     (vector_extract (v2i64 FPR128:$Rn), (i64 1)))),
           (i64 (ADDPv2i64p (v2i64 FPR128:$Rn)))>;
-def : Pat<(f64 (fadd (vector_extract (v2f64 FPR128:$Rn), (i64 0)),
-                     (vector_extract (v2f64 FPR128:$Rn), (i64 1)))),
+def : Pat<(f64 (any_fadd (vector_extract (v2f64 FPR128:$Rn), (i64 0)),
+                         (vector_extract (v2f64 FPR128:$Rn), (i64 1)))),
           (f64 (FADDPv2i64p (v2f64 FPR128:$Rn)))>;
 // vector_extract on 64-bit vectors gets promoted to a 128 bit vector,
 // so we match on v4f32 here, not v2f32. This will also catch adding
 // the low two lanes of a true v4f32 vector.
-def : Pat<(fadd (vector_extract (v4f32 FPR128:$Rn), (i64 0)),
-                (vector_extract (v4f32 FPR128:$Rn), (i64 1))),
+def : Pat<(any_fadd (vector_extract (v4f32 FPR128:$Rn), (i64 0)),
+                    (vector_extract (v4f32 FPR128:$Rn), (i64 1))),
           (f32 (FADDPv2i32p (EXTRACT_SUBREG FPR128:$Rn, dsub)))>;
-def : Pat<(fadd (vector_extract (v8f16 FPR128:$Rn), (i64 0)),
-                (vector_extract (v8f16 FPR128:$Rn), (i64 1))),
+def : Pat<(any_fadd (vector_extract (v8f16 FPR128:$Rn), (i64 0)),
+                    (vector_extract (v8f16 FPR128:$Rn), (i64 1))),
           (f16 (FADDPv2i16p (EXTRACT_SUBREG FPR128:$Rn, dsub)))>;

 // Scalar 64-bit shifts in FPR64 registers.
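For reference, any_fadd is not defined by this patch; it is the existing PatFrags helper in llvm/include/llvm/Target/TargetSelectionDAG.td that matches either the ordinary or the strict node, essentially:

    def any_fadd : PatFrags<(ops node:$lhs, node:$rhs),
                            [(strict_fadd node:$lhs, node:$rhs),
                             (fadd node:$lhs, node:$rhs)]>;

Using it lets the FADDP patterns above serve both fadd and strict_fadd without duplicating each def : Pat.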
diff --git a/llvm/test/CodeGen/AArch64/faddp.ll b/llvm/test/CodeGen/AArch64/faddp.ll
--- a/llvm/test/CodeGen/AArch64/faddp.ll
+++ b/llvm/test/CodeGen/AArch64/faddp.ll
@@ -100,3 +100,83 @@
   %1 = extractelement <2 x i64> %0, i32 0
   ret i64 %1
 }
+
+define float @faddp_2xfloat_strict(<2 x float> %a) #0 {
+; CHECK-LABEL: faddp_2xfloat_strict:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    faddp s0, v0.2s
+; CHECK-NEXT:    ret
+entry:
+  %shift = shufflevector <2 x float> %a, <2 x float> undef, <2 x i32> <i32 1, i32 undef>
+  %0 = call <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %a, <2 x float> %shift, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  %1 = extractelement <2 x float> %0, i32 0
+  ret float %1
+}
+
+define float @faddp_4xfloat_strict(<4 x float> %a) #0 {
+; CHECK-LABEL: faddp_4xfloat_strict:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    faddp s0, v0.2s
+; CHECK-NEXT:    ret
+entry:
+  %shift = shufflevector <4 x float> %a, <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
+  %0 = call <4 x float> @llvm.experimental.constrained.fadd.v4f32(<4 x float> %a, <4 x float> %shift, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  %1 = extractelement <4 x float> %0, i32 0
+  ret float %1
+}
+
+define float @faddp_4xfloat_commute_strict(<4 x float> %a) #0 {
+; CHECK-LABEL: faddp_4xfloat_commute_strict:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    faddp s0, v0.2s
+; CHECK-NEXT:    ret
+entry:
+  %shift = shufflevector <4 x float> %a, <4 x float> undef, <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
+  %0 = call <4 x float> @llvm.experimental.constrained.fadd.v4f32(<4 x float> %shift, <4 x float> %a, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  %1 = extractelement <4 x float> %0, i32 0
+  ret float %1
+}
+
+define float @faddp_2xfloat_commute_strict(<2 x float> %a) #0 {
+; CHECK-LABEL: faddp_2xfloat_commute_strict:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    faddp s0, v0.2s
+; CHECK-NEXT:    ret
+entry:
+  %shift = shufflevector <2 x float> %a, <2 x float> undef, <2 x i32> <i32 1, i32 undef>
+  %0 = call <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float> %shift, <2 x float> %a, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  %1 = extractelement <2 x float> %0, i32 0
+  ret float %1
+}
+
+define double @faddp_2xdouble_strict(<2 x double> %a) #0 {
+; CHECK-LABEL: faddp_2xdouble_strict:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    faddp d0, v0.2d
+; CHECK-NEXT:    ret
+entry:
+  %shift = shufflevector <2 x double> %a, <2 x double> undef, <2 x i32> <i32 1, i32 undef>
+  %0 = call <2 x double> @llvm.experimental.constrained.fadd.v2f64(<2 x double> %a, <2 x double> %shift, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  %1 = extractelement <2 x double> %0, i32 0
+  ret double %1
+}
+
+define double @faddp_2xdouble_commute_strict(<2 x double> %a) #0 {
+; CHECK-LABEL: faddp_2xdouble_commute_strict:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    faddp d0, v0.2d
+; CHECK-NEXT:    ret
+entry:
+  %shift = shufflevector <2 x double> %a, <2 x double> undef, <2 x i32> <i32 1, i32 undef>
+  %0 = call <2 x double> @llvm.experimental.constrained.fadd.v2f64(<2 x double> %shift, <2 x double> %a, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
+  %1 = extractelement <2 x double> %0, i32 0
+  ret double %1
+}
+
+attributes #0 = { strictfp }
+
+declare <2 x float> @llvm.experimental.constrained.fadd.v2f32(<2 x float>, <2 x float>, metadata, metadata)
+declare <4 x float> @llvm.experimental.constrained.fadd.v4f32(<4 x float>, <4 x float>, metadata, metadata)
+declare <2 x double> @llvm.experimental.constrained.fadd.v2f64(<2 x double>, <2 x double>, metadata, metadata)
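Note that the v8f16 pattern also switched to any_fadd, but the tests above only cover f32 and f64. A hypothetical f16 test might look like the sketch below (not part of the patch and not run through update_llc_test_checks; it assumes a RUN line enabling fullfp16, e.g. llc -mtriple=aarch64 -mattr=+fullfp16, with #0 being the strictfp attribute defined above):

    define half @faddp_2xhalf_strict(<8 x half> %a) #0 {
    entry:
      %0 = extractelement <8 x half> %a, i64 0
      %1 = extractelement <8 x half> %a, i64 1
      %2 = call half @llvm.experimental.constrained.fadd.f16(half %0, half %1, metadata !"round.tonearest", metadata !"fpexcept.strict") #0
      ret half %2
    }

    declare half @llvm.experimental.constrained.fadd.f16(half, half, metadata, metadata)

This should exercise the FADDPv2i16p pattern directly, since two lane extracts feeding a scalar strict fadd lower to exactly the any_fadd-of-vector_extracts DAG that the pattern matches.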