Index: llvm/include/llvm/CodeGen/ISDOpcodes.h =================================================================== --- llvm/include/llvm/CodeGen/ISDOpcodes.h +++ llvm/include/llvm/CodeGen/ISDOpcodes.h @@ -485,6 +485,9 @@ /// separately rounded operations. FMAD, + /// FCMA - Perform complex a * b + c with no intermediate rounding step. + FCMA, + /// FCOPYSIGN(X, Y) - Return the value of X with the sign of Y. NOTE: This /// DAG node does not require that X and Y have the same type, just that /// they are both floating point. X and the result must have the same type. Index: llvm/include/llvm/IR/Intrinsics.td =================================================================== --- llvm/include/llvm/IR/Intrinsics.td +++ llvm/include/llvm/IR/Intrinsics.td @@ -1997,6 +1997,13 @@ [llvm_anyvector_ty]>; } +//===----- Complex math intrinsics ----------------------------------------===// + +// llvm.fcmuladd(a, b, c): complex a * b + c; SelectionDAG builds ISD::FCMA. +def int_fcmuladd: DefaultAttrsIntrinsic<[llvm_anyfloat_ty], + [LLVMMatchType<0>, LLVMMatchType<0>, + LLVMMatchType<0>], [IntrNoMem, IntrSpeculatable]>; + //===----- Matrix intrinsics ---------------------------------------------===// def int_matrix_transpose Index: llvm/include/llvm/Target/TargetSelectionDAG.td =================================================================== --- llvm/include/llvm/Target/TargetSelectionDAG.td +++ llvm/include/llvm/Target/TargetSelectionDAG.td @@ -481,6 +481,7 @@ def frem : SDNode<"ISD::FREM" , SDTFPBinOp>; def fma : SDNode<"ISD::FMA" , SDTFPTernaryOp, [SDNPCommutative]>; def fmad : SDNode<"ISD::FMAD" , SDTFPTernaryOp, [SDNPCommutative]>; +def fcma : SDNode<"ISD::FCMA" , SDTFPTernaryOp, [SDNPCommutative]>; def fabs : SDNode<"ISD::FABS" , SDTFPUnaryOp>; def fminnum : SDNode<"ISD::FMINNUM" , SDTFPBinOp, [SDNPCommutative, SDNPAssociative]>; Index: llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp +++ llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp @@ -4816,6 +4816,7 @@ 
Results.push_back(Tmp1.getValue(1)); break; case ISD::FMA: + case ISD::FCMA: Tmp1 = DAG.getNode(ISD::FP_EXTEND, dl, NVT, Node->getOperand(0)); Tmp2 = DAG.getNode(ISD::FP_EXTEND, dl, NVT, Node->getOperand(1)); Tmp3 = DAG.getNode(ISD::FP_EXTEND, dl, NVT, Node->getOperand(2)); Index: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp +++ llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp @@ -385,6 +385,7 @@ case ISD::FP_ROUND: case ISD::FP_EXTEND: case ISD::FMA: + case ISD::FCMA: case ISD::SIGN_EXTEND_INREG: case ISD::ANY_EXTEND_VECTOR_INREG: case ISD::SIGN_EXTEND_VECTOR_INREG: Index: llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp +++ llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp @@ -157,6 +157,7 @@ R = ScalarizeVecRes_BinOp(N); break; case ISD::FMA: + case ISD::FCMA: case ISD::FSHL: case ISD::FSHR: R = ScalarizeVecRes_TernaryOp(N); @@ -1119,6 +1120,7 @@ SplitVecRes_BinOp(N, Lo, Hi); break; case ISD::FMA: case ISD::VP_FMA: + case ISD::FCMA: case ISD::FSHL: case ISD::VP_FSHL: case ISD::FSHR: @@ -4128,6 +4130,7 @@ Res = WidenVecRes_Unary(N); break; case ISD::FMA: case ISD::VP_FMA: + case ISD::FCMA: case ISD::FSHL: case ISD::VP_FSHL: case ISD::FSHR: Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp +++ llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp @@ -4818,6 +4818,7 @@ case ISD::FSIN: case ISD::FCOS: case ISD::FMA: + case ISD::FCMA: case ISD::FMAD: { if (SNaN) return true; @@ -6613,6 +6614,13 @@ "Operand is DELETED_NODE!"); // Perform various simplifications. 
switch (Opcode) { + case ISD::FCMA: { + assert(VT.isFloatingPoint() && "This operator only applies to FP types!"); + assert(N1.getValueType() == VT && N2.getValueType() == VT && + N3.getValueType() == VT && "FCMA types must match!"); + // TODO: Add constant folding for FCMA. + break; + } case ISD::FMA: { assert(VT.isFloatingPoint() && "This operator only applies to FP types!"); assert(N1.getValueType() == VT && N2.getValueType() == VT && Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -6477,6 +6477,14 @@ } return; } + case Intrinsic::fcmuladd: { + setValue(&I, DAG.getNode(ISD::FCMA, sdl, + getValue(I.getArgOperand(0)).getValueType(), + getValue(I.getArgOperand(0)), + getValue(I.getArgOperand(1)), + getValue(I.getArgOperand(2)), Flags)); + return; + } case Intrinsic::convert_to_fp16: setValue(&I, DAG.getNode(ISD::BITCAST, sdl, MVT::i16, DAG.getNode(ISD::FP_ROUND, sdl, MVT::f16, Index: llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp =================================================================== --- llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp +++ llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp @@ -263,6 +263,7 @@ case ISD::FDIV: return "fdiv"; case ISD::STRICT_FDIV: return "strict_fdiv"; case ISD::FMA: return "fma"; + case ISD::FCMA: return "fcma"; case ISD::STRICT_FMA: return "strict_fma"; case ISD::FMAD: return "fmad"; case ISD::FREM: return "frem"; Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1562,6 +1562,7 @@ setOperationAction(ISD::SHL, VT, Custom); setOperationAction(ISD::OR, VT, Custom); setOperationAction(ISD::SETCC, VT, Custom); + setOperationAction(ISD::FCMA, VT, 
Custom); setOperationAction(ISD::CONCAT_VECTORS, VT, Legal); setOperationAction(ISD::SELECT, VT, Expand); @@ -5912,6 +5913,19 @@ return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMUL_PRED); case ISD::FMA: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FMA_PRED); + case ISD::FCMA: { + SDLoc dl(Op); + SDValue vcmla0 = + DAG.getTargetConstant(Intrinsic::aarch64_neon_vcmla_rot0, dl, MVT::i64); + SDValue vcmla90 = DAG.getTargetConstant(Intrinsic::aarch64_neon_vcmla_rot90, + dl, MVT::i64); + SDValue Part1 = + DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(), vcmla0, + Op.getOperand(2), Op.getOperand(0), Op.getOperand(1)); + return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(), vcmla90, + Part1, Op.getOperand(0), Op.getOperand(1)); + } + case ISD::FDIV: return LowerToPredicatedOp(Op, DAG, AArch64ISD::FDIV_PRED); case ISD::FNEG: Index: llvm/test/CodeGen/AArch64/complex-intrinsics.ll =================================================================== --- /dev/null +++ llvm/test/CodeGen/AArch64/complex-intrinsics.ll @@ -0,0 +1,67 @@ +; RUN: llc < %s -asm-verbose=false -mtriple=arm64-eabi -aarch64-neon-syntax=apple -mattr="+complxnum" | FileCheck %s + +define <2 x float> @test_v2f32(<2 x float>* %A, <2 x float>* %B, <2 x float>* %C) nounwind { +;CHECK-LABEL: test_v2f32: +;CHECK: fcmla.2s{{.*}}#0 +;CHECK: fcmla.2s{{.*}}#90 +;CHECK-NOT: fcmla.2s + %tmp1 = load <2 x float>, <2 x float>* %A + %tmp2 = load <2 x float>, <2 x float>* %B + %tmp3 = load <2 x float>, <2 x float>* %C + %tmp4 = call <2 x float> @llvm.fcmuladd.v2f32(<2 x float> %tmp1, <2 x float> %tmp2, <2 x float> %tmp3) + ret <2 x float> %tmp4 +} + +define <4 x float> @test_v4f32(<4 x float> %A, <4 x float> %B, <4 x float> %C) nounwind { +;CHECK-LABEL: test_v4f32: +;CHECK: fcmla.4s{{.*}}#0 +;CHECK: fcmla.4s{{.*}}#90 +;CHECK-NOT: fcmla.4s + %tmp4 = call <4 x float> @llvm.fcmuladd.v4f32(<4 x float> %A, <4 x float> %B, <4 x float> %C) + ret <4 x float> %tmp4 +} + +define <8 x float> @test_v8f32(<8 x 
float>* %A, <8 x float>* %B, <8 x float>* %C) nounwind { +;CHECK-LABEL: test_v8f32: +;CHECK: fcmla.4s{{.*}}#0 +;CHECK: fcmla.4s{{.*}}#0 +;CHECK: fcmla.4s{{.*}}#90 +;CHECK: fcmla.4s{{.*}}#90 +;CHECK-NOT: fcmla.4s + %tmp1 = load <8 x float>, <8 x float>* %A + %tmp2 = load <8 x float>, <8 x float>* %B + %tmp3 = load <8 x float>, <8 x float>* %C + %tmp4 = call <8 x float> @llvm.fcmuladd.v8f32(<8 x float> %tmp1, <8 x float> %tmp2, <8 x float> %tmp3) + ret <8 x float> %tmp4 +} + +define <2 x double> @test_v2f64(<2 x double>* %A, <2 x double>* %B, <2 x double>* %C) nounwind { +;CHECK-LABEL: test_v2f64: +;CHECK: fcmla.2d{{.*}}#0 +;CHECK: fcmla.2d{{.*}}#90 +;CHECK-NOT: fcmla.2d + %tmp1 = load <2 x double>, <2 x double>* %A + %tmp2 = load <2 x double>, <2 x double>* %B + %tmp3 = load <2 x double>, <2 x double>* %C + %tmp4 = call <2 x double> @llvm.fcmuladd.v2f64(<2 x double> %tmp1, <2 x double> %tmp2, <2 x double> %tmp3) + ret <2 x double> %tmp4 +} + +define <4 x double> @test_v4f64(<4 x double>* %A, <4 x double>* %B, <4 x double>* %C) nounwind { +;CHECK-LABEL: test_v4f64: +;CHECK: fcmla.2d{{.*}}#0 +;CHECK: fcmla.2d{{.*}}#0 +;CHECK: fcmla.2d{{.*}}#90 +;CHECK: fcmla.2d{{.*}}#90 +;CHECK-NOT: fcmla.2d + %tmp1 = load <4 x double>, <4 x double>* %A + %tmp2 = load <4 x double>, <4 x double>* %B + %tmp3 = load <4 x double>, <4 x double>* %C + %tmp4 = call <4 x double> @llvm.fcmuladd.v4f64(<4 x double> %tmp1, <4 x double> %tmp2, <4 x double> %tmp3) + ret <4 x double> %tmp4 +} +declare <2 x float> @llvm.fcmuladd.v2f32(<2 x float>, <2 x float>, <2 x float>) nounwind readnone +declare <4 x float> @llvm.fcmuladd.v4f32(<4 x float>, <4 x float>, <4 x float>) nounwind readnone +declare <8 x float> @llvm.fcmuladd.v8f32(<8 x float>, <8 x float>, <8 x float>) nounwind readnone +declare <2 x double> @llvm.fcmuladd.v2f64(<2 x double>, <2 x double>, <2 x double>) nounwind readnone +declare <4 x double> @llvm.fcmuladd.v4f64(<4 x double>, <4 x double>, <4 x double>) nounwind readnone Index: 
mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td =================================================================== --- mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td +++ mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td @@ -96,6 +96,7 @@ def LLVM_FFloorOp : LLVM_UnaryIntrOpF<"floor">; def LLVM_FMAOp : LLVM_TernarySameArgsIntrOpF<"fma">; def LLVM_FMulAddOp : LLVM_TernarySameArgsIntrOpF<"fmuladd">; +def LLVM_FCMulAddOp : LLVM_TernarySameArgsIntrOpF<"fcmuladd">; def LLVM_Log10Op : LLVM_UnaryIntrOpF<"log10">; def LLVM_Log2Op : LLVM_UnaryIntrOpF<"log2">; def LLVM_LogOp : LLVM_UnaryIntrOpF<"log">;