Index: llvm/include/llvm/IR/IntrinsicsAArch64.td =================================================================== --- llvm/include/llvm/IR/IntrinsicsAArch64.td +++ llvm/include/llvm/IR/IntrinsicsAArch64.td @@ -3139,3 +3139,9 @@ def int_aarch64_sve_add_single_x2 : SME2_VG2_Multi_Single_Intrinsic; def int_aarch64_sve_add_single_x4 : SME2_VG4_Multi_Single_Intrinsic; } + +//===----------------------------------------------------------------------===// +// Inlining target math intrinsics + +def int_aarch64_cos : AdvSIMD_1FloatArg_Intrinsic; +def int_aarch64_sin : AdvSIMD_1FloatArg_Intrinsic; Index: llvm/lib/Target/AArch64/AArch64.h =================================================================== --- llvm/lib/Target/AArch64/AArch64.h +++ llvm/lib/Target/AArch64/AArch64.h @@ -70,6 +70,7 @@ FunctionPass *createAArch64PostSelectOptimize(); FunctionPass *createAArch64StackTaggingPass(bool IsOptNone); FunctionPass *createAArch64StackTaggingPreRAPass(); +FunctionPass *createAArch64InlineMathPass(); void initializeAArch64A53Fix835769Pass(PassRegistry&); void initializeAArch64A57FPLoadBalancingPass(PassRegistry&); @@ -84,6 +85,7 @@ void initializeAArch64DAGToDAGISelPass(PassRegistry &); void initializeAArch64DeadRegisterDefinitionsPass(PassRegistry&); void initializeAArch64ExpandPseudoPass(PassRegistry &); +void initializeAArch64InlineMathPass(PassRegistry &); void initializeAArch64KCFIPass(PassRegistry &); void initializeAArch64LoadStoreOptPass(PassRegistry&); void initializeAArch64LowerHomogeneousPrologEpilogPass(PassRegistry &); Index: llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp +++ llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp @@ -15,6 +15,7 @@ #include "MCTargetDesc/AArch64AddressingModes.h" #include "llvm/ADT/APSInt.h" #include "llvm/CodeGen/ISDOpcodes.h" +#include "llvm/CodeGen/MachineModuleInfo.h" #include "llvm/CodeGen/SelectionDAGISel.h" 
#include "llvm/IR/Function.h" // To access function attributes. #include "llvm/IR/GlobalValue.h" @@ -344,6 +345,10 @@ SDValue createTuple(ArrayRef Vecs, const unsigned RegClassIDs[], const unsigned SubRegs[]); + void SelectCosSVE(SDNode *N, unsigned IntNo, bool isFloat); + void SelectSinSVE(SDNode *N, unsigned IntNo, bool isFloat); + void SelectInlineMath(unsigned IntNo, SDNode *N, EVT VT); + void SelectTable(SDNode *N, unsigned NumVecs, unsigned Opc, bool isExt); bool tryIndexedLoad(SDNode *N); @@ -1509,6 +1514,457 @@ return SDValue(N, 0); } +#include "AArch64InlineMathSymbols.h" +#include "llvm/Support/Allocator.h" + +static SDValue getTargetGV(SelectionDAG *DAG, const Module *M, SDLoc DL, + const char *GVName) { + const GlobalVariable *GV = M->getNamedGlobal(GVName); + return DAG->getTargetGlobalAddress(GV, DL, MVT::i64, 0, AArch64II::MO_PAGE); +} + +static SDValue getTargetGVLow12(SelectionDAG *DAG, const Module *M, SDLoc DL, + const char *GVName) { + const GlobalVariable *GV = M->getNamedGlobal(GVName); + const unsigned Flag = AArch64II::MO_PAGEOFF | AArch64II::MO_NC; + return DAG->getTargetGlobalAddress(GV, DL, MVT::i64, 0, Flag); +} + +static void setMMOperand(SelectionDAG *DAG, SDNode *Node, EVT MemVT) { + MachineMemOperand *MemOp; + Align Alignment = DAG->getEVTAlign(MemVT); + MemOp = DAG->getMachineFunction().getMachineMemOperand( + MachinePointerInfo::getConstantPool(DAG->getMachineFunction()), + MachineMemOperand::MOLoad | MachineMemOperand::MOInvariant, + MemVT.getStoreSize().getFixedSize(), Alignment); + DAG->setNodeMemRefs(cast(Node), {MemOp}); +} + +static SDValue createPTRUE(SelectionDAG *DAG, SDLoc DL, SDValue Op1, + bool isFloat = false) { + unsigned MI; + if (isFloat) { + MI = AArch64::PTRUE_S; + } else { + MI = AArch64::PTRUE_D; + } + return SDValue(DAG->getMachineNode(MI, DL, MVT::nxv2i1, Op1), 0); +} + +static SDValue createDUP_ZI(SelectionDAG *DAG, SDLoc DL, SDValue Op1, + SDValue Op2, bool isFloat = false) { + EVT VT; + unsigned MI; + if 
(isFloat) { + VT = MVT::nxv4f32; + MI = AArch64::DUP_ZI_S; + } else { + VT = MVT::nxv2f64; + MI = AArch64::DUP_ZI_D; + } + return SDValue(DAG->getMachineNode(MI, DL, VT, Op1, Op2), 0); +} + +static SDValue createLD1R_IMM(SelectionDAG *DAG, SDLoc DL, SDValue Op1, + SDValue Op2, uint64_t i, bool isFloat = false, + bool isInvariant = false) { + EVT VT; + unsigned MI; + if (isFloat) { + VT = MVT::nxv4f32; + MI = AArch64::LD1RW_IMM; + } else { + VT = MVT::nxv2f64; + MI = AArch64::LD1RD_IMM; + } + SDNode *Node = DAG->getMachineNode(MI, DL, VT, Op1, Op2, + DAG->getTargetConstant(i, DL, MVT::i64)); + if (isInvariant) { + if (isFloat) { + setMMOperand(DAG, Node, MVT::f32); + } else { + setMMOperand(DAG, Node, MVT::f64); + } + } + return SDValue(Node, 0); +} + +static SDValue createFMAD_ZPmZZ(SelectionDAG *DAG, SDLoc DL, SDValue Op1, + SDValue Op2, SDValue Op3, SDValue Op4, + bool isFloat = false) { + EVT VT; + unsigned MI; + if (isFloat) { + VT = MVT::nxv4f32; + MI = AArch64::FMAD_ZPmZZ_S; + } else { + VT = MVT::nxv2f64; + MI = AArch64::FMAD_ZPmZZ_D; + } + SDValue Ops[] = {Op1, Op2, Op3, Op4}; + return SDValue(DAG->getMachineNode(MI, DL, VT, Ops), 0); +} + +static SDValue createFACGT_PPzZZ(SelectionDAG *DAG, SDLoc DL, SDValue Op1, + SDValue Op2, SDValue Op3, + bool isFloat = false) { + unsigned MI; + if (isFloat) { + MI = AArch64::FACGT_PPzZZ_S; + } else { + MI = AArch64::FACGT_PPzZZ_D; + } + return SDValue(DAG->getMachineNode(MI, DL, MVT::nxv2i1, Op1, Op2, Op3), 0); +} + +static SDValue createFSUB_ZZZ(SelectionDAG *DAG, SDLoc DL, SDValue Op1, + SDValue Op2, bool isFloat = false) { + EVT VT; + unsigned MI; + if (isFloat) { + VT = MVT::nxv4f32; + MI = AArch64::FSUB_ZZZ_S; + } else { + VT = MVT::nxv2f64; + MI = AArch64::FSUB_ZZZ_D; + } + return SDValue(DAG->getMachineNode(MI, DL, VT, Op1, Op2), 0); +} + +static SDValue createFMSB_ZPmZZ(SelectionDAG *DAG, SDLoc DL, SDValue Op1, + SDValue Op2, SDValue Op3, SDValue Op4, + bool isFloat = false) { + EVT VT; + unsigned MI; + if 
(isFloat) { + VT = MVT::nxv4f32; + MI = AArch64::FMSB_ZPmZZ_S; + } else { + VT = MVT::nxv2f64; + MI = AArch64::FMSB_ZPmZZ_D; + } + SDValue Ops[] = {Op1, Op2, Op3, Op4}; + return SDValue(DAG->getMachineNode(MI, DL, VT, Ops), 0); +} + +static SDValue createFTSMUL_ZZZ(SelectionDAG *DAG, SDLoc DL, SDValue Op1, + SDValue Op2, bool isFloat = false) { + EVT VT; + unsigned MI; + if (isFloat) { + VT = MVT::nxv4f32; + MI = AArch64::FTSMUL_ZZZ_S; + } else { + VT = MVT::nxv2f64; + MI = AArch64::FTSMUL_ZZZ_D; + } + return SDValue(DAG->getMachineNode(MI, DL, VT, Op1, Op2), 0); +} + +static SDValue createFTSSEL_ZZZ(SelectionDAG *DAG, SDLoc DL, SDValue Op1, + SDValue Op2, bool isFloat = false) { + EVT VT; + unsigned MI; + if (isFloat) { + VT = MVT::nxv4f32; + MI = AArch64::FTSSEL_ZZZ_S; + } else { + VT = MVT::nxv2f64; + MI = AArch64::FTSSEL_ZZZ_D; + } + return SDValue(DAG->getMachineNode(MI, DL, VT, Op1, Op2), 0); +} + +static SDValue createFTMAD_ZZI(SelectionDAG *DAG, SDLoc DL, SDValue Op1, + SDValue Op2, uint64_t i, bool isFloat = false) { + EVT VT; + unsigned MI; + if (isFloat) { + VT = MVT::nxv4f32; + MI = AArch64::FTMAD_ZZI_S; + } else { + VT = MVT::nxv2f64; + MI = AArch64::FTMAD_ZZI_D; + } + return SDValue(DAG->getMachineNode(MI, DL, VT, Op1, Op2, + DAG->getTargetConstant(i, DL, MVT::i64)), + 0); +} + +static SDValue createFMUL_ZZZ(SelectionDAG *DAG, SDLoc DL, SDValue Op1, + SDValue Op2, bool isFloat = false) { + EVT VT; + unsigned MI; + if (isFloat) { + VT = MVT::nxv4f32; + MI = AArch64::FMUL_ZZZ_S; + } else { + VT = MVT::nxv2f64; + MI = AArch64::FMUL_ZZZ_D; + } + return SDValue(DAG->getMachineNode(MI, DL, VT, Op1, Op2), 0); +} + +static SDValue createSEL_ZPZZ(SelectionDAG *DAG, SDLoc DL, SDValue Op1, + SDValue Op2, SDValue Op3, bool isFloat = false) { + EVT VT; + unsigned MI; + if (isFloat) { + VT = MVT::nxv4f32; + MI = AArch64::SEL_ZPZZ_S; + } else { + VT = MVT::nxv2f64; + MI = AArch64::SEL_ZPZZ_D; + } + return SDValue(DAG->getMachineNode(MI, DL, VT, Op1, Op2, Op3), 0); 
+} + +static SDValue createADDXri(SelectionDAG *DAG, SDLoc DL, SDValue Op, + SDValue Inc) { + return SDValue(DAG->getMachineNode(AArch64::ADDXri, DL, MVT::i64, Op, Inc, + DAG->getTargetConstant(0, DL, MVT::i64)), + 0); +} + +static SDValue createADRP(SelectionDAG *DAG, const Module *M, SDLoc DL, + const char *GVName) { + SDValue SD = getTargetGV(DAG, M, DL, GVName); + return SDValue(DAG->getMachineNode(AArch64::ADRP, DL, MVT::i64, SD), 0); +} + +void AArch64DAGToDAGISel::SelectCosSVE(SDNode *N, unsigned IntNo, + bool isFloat = false) { + assert(N->getNumValues() == 1 && "The number of values defined should be 1."); + assert(((N->getNumOperands() - 1) == 1 || (N->getNumOperands() - 1) == 2) && + "The number of values used should be 1 or 2."); + + SDLoc DL(N); + unsigned ArgIndex = IntNo == Intrinsic::aarch64_cos ? 1 : 2; + const SDValue &Arg = N->getOperand(ArgIndex); + const Module *M = CurDAG->getMachineFunction().getMMI().getModule(); + + // ptrue p0.t, ALL + SDValue P0 = createPTRUE( + CurDAG, DL, CurDAG->getTargetConstant(31, DL, MVT::i64), isFloat); + // adrp x0, .llvm.cos.nxv?f??.tbl + const char *TableName = isFloat ? 
SN_COS_NXV4F32_TBL : SN_COS_NXV2F64_TBL; + SDValue X0 = createADRP(CurDAG, M, DL, TableName); + // add x0, x0, :lo12:.llvm.cos.nxv?f??.tbl + SDValue Addr = getTargetGVLow12(CurDAG, M, DL, TableName); + X0 = createADDXri(CurDAG, DL, X0, Addr); + // fmov z2.t, 0.000000e+00 + SDValue Z2 = + createDUP_ZI(CurDAG, DL, CurDAG->getTargetConstant(0, DL, MVT::i64), + CurDAG->getTargetConstant(0, DL, MVT::i64), isFloat); + SDValue Z0; + SDValue Z5; + SDValue Z6; + SDValue Z3; + SDValue Z16; + SDValue Z4; + SDValue Z1; + if (isFloat) { + // ld1rd {z0.t}, p0/z, [x0] + Z0 = createLD1R_IMM(CurDAG, DL, P0, X0, 0, true, true); + // ld1rd {z5.t}, p0/z, [x0, 8] + Z5 = createLD1R_IMM(CurDAG, DL, P0, X0, 2, true, true); + // ld1rd {z6.t}, p0/z, [x0, 16] + Z6 = createLD1R_IMM(CurDAG, DL, P0, X0, 4, true, true); + // ld1rd {z3.t}, p0/z, [x0, 20] + Z3 = createLD1R_IMM(CurDAG, DL, P0, X0, 5, true, true); + // ld1rd {z16.t}, p0/z, [x0, 12] + Z16 = createLD1R_IMM(CurDAG, DL, P0, X0, 3, true, true); + // ld1rd {z4.t}, p0/z, [x0, 24] + Z4 = createLD1R_IMM(CurDAG, DL, P0, X0, 6, true, true); + // ld1rd {z1.t}, p0/z, [x0, 28] + Z1 = createLD1R_IMM(CurDAG, DL, P0, X0, 7, true, true); + } else { + // ld1rd {z5.t}, p0/z, [x0, 16] + Z5 = createLD1R_IMM(CurDAG, DL, P0, X0, 2, false, true); + // ld1rd {z0.t}, p0/z, [x0, 48] + Z0 = createLD1R_IMM(CurDAG, DL, P0, X0, 6, false, true); + // ld1rd {z6.t}, p0/z, [x0, 24] + Z6 = createLD1R_IMM(CurDAG, DL, P0, X0, 3, false, true); + // ld1rd {z3.t}, p0/z, [x0, 32] + Z3 = createLD1R_IMM(CurDAG, DL, P0, X0, 4, false, true); + // ld1rd {z16.t}, p0/z, [x0] + Z16 = createLD1R_IMM(CurDAG, DL, P0, X0, 0, false, true); + // ld1rd {z4.t}, p0/z, [x0, 40] + Z4 = createLD1R_IMM(CurDAG, DL, P0, X0, 5, false, true); + // ld1rd {z1.t}, p0/z, [x0, 56] + Z1 = createLD1R_IMM(CurDAG, DL, P0, X0, 7, false, true); + } + + // fmad z0.t, p1/m, z7.t, z5.t + Z0 = createFMAD_ZPmZZ(CurDAG, DL, P0, Z0, Arg, Z5, isFloat); + // facgt p0.t, p1/z, z7.t, z16.t + SDValue P1 = 
createFACGT_PPzZZ(CurDAG, DL, P0, Arg, Z16, isFloat); + // fsub z5.t, z0.t, z5.t + Z5 = createFSUB_ZZZ(CurDAG, DL, Z0, Z5, isFloat); + // fmsb z6.t, p1/m, z5.t, z7.t + Z6 = createFMSB_ZPmZZ(CurDAG, DL, P0, Z6, Z5, Arg, isFloat); + // fmsb z3.t, p1/m, z5.t, z6.t + Z3 = createFMSB_ZPmZZ(CurDAG, DL, P0, Z3, Z5, Z6, isFloat); + // fmsb z4.t, p1/m, z5.t, z3.t + Z4 = createFMSB_ZPmZZ(CurDAG, DL, P0, Z4, Z5, Z3, isFloat); + + // ftsmul z3.t, z4.t, z0.t + Z3 = createFTSMUL_ZZZ(CurDAG, DL, Z4, Z0, isFloat); + // ftssel z0.t, z4.t, z0.t + Z0 = createFTSSEL_ZZZ(CurDAG, DL, Z4, Z0, isFloat); + + if (!isFloat) { + // ftmad z2.t, z2.t, z3.t, 7 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 7); + // ftmad z2.t, z2.t, z3.t, 6 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 6); + // ftmad z2.t, z2.t, z3.t, 5 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 5); + } + // ftmad z2.t, z2.t, z3.t, 4 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 4, isFloat); + // ftmad z2.t, z2.t, z3.t, 3 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 3, isFloat); + // ftmad z2.t, z2.t, z3.t, 2 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 2, isFloat); + // ftmad z2.t, z2.t, z3.t, 1 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 1, isFloat); + // ftmad z2.t, z2.t, z3.t, 0 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 0, isFloat); + // fmul z0.t, z2.t, z0.t + Z0 = createFMUL_ZZZ(CurDAG, DL, Z2, Z0, isFloat); + // sel z0.t, p0, z1.t, z0.t + Z0 = createSEL_ZPZZ(CurDAG, DL, P1, Z1, Z0, isFloat); + + CurDAG->ReplaceAllUsesWith(N, &Z0); + ReplaceNode(N, Z0.getNode()); +} + +void AArch64DAGToDAGISel::SelectSinSVE(SDNode *N, unsigned IntNo, + bool isFloat = false) { + assert(N->getNumValues() == 1 && "The number of values defined should be 1."); + assert(((N->getNumOperands() - 1) == 1 || (N->getNumOperands() - 1) == 2) && + "The number of values used should be 1 or 2."); + + SDLoc DL(N); + unsigned ArgIndex = IntNo == Intrinsic::aarch64_sin ? 
1 : 2; + const SDValue &Arg = N->getOperand(ArgIndex); + const Module *M = CurDAG->getMachineFunction().getMMI().getModule(); + + // adrp x0, .llvm.sin.nxv?f??.tbl + const char *TableName = isFloat ? SN_SIN_NXV4F32_TBL : SN_SIN_NXV2F64_TBL; + SDValue X0 = createADRP(CurDAG, M, DL, TableName); + // add x0, x0, :lo12:.llvm.sin.nxv?f??.tbl + SDValue Addr = getTargetGVLow12(CurDAG, M, DL, TableName); + X0 = createADDXri(CurDAG, DL, X0, Addr); + // ptrue p1.t, ALL + SDValue P1 = createPTRUE( + CurDAG, DL, CurDAG->getTargetConstant(31, DL, MVT::i64), isFloat); + // fmov z2.t, 0.000000e+00 + SDValue Z2 = + createDUP_ZI(CurDAG, DL, CurDAG->getTargetConstant(0, DL, MVT::i64), + CurDAG->getTargetConstant(0, DL, MVT::i64), isFloat); + SDValue Z0; + SDValue Z5; + SDValue Z6; + SDValue Z3; + SDValue Z16; + SDValue Z4; + SDValue Z1; + if (isFloat) { + // ld1rd {z0.t}, p0/z, [x0] + Z0 = createLD1R_IMM(CurDAG, DL, P1, X0, 0, true, true); + // ld1rd {z5.t}, p0/z, [x0, 4] + Z5 = createLD1R_IMM(CurDAG, DL, P1, X0, 1, true, true); + // ld1rd {z6.t}, p0/z, [x0, 16] + Z6 = createLD1R_IMM(CurDAG, DL, P1, X0, 4, true, true); + // ld1rd {z3.t}, p0/z, [x0, 20] + Z3 = createLD1R_IMM(CurDAG, DL, P1, X0, 5, true, true); + // ld1rd {z16.t}, p0/z, [x0, 12] + Z16 = createLD1R_IMM(CurDAG, DL, P1, X0, 3, true, true); + // ld1rd {z4.t}, p0/z, [x0, 24] + Z4 = createLD1R_IMM(CurDAG, DL, P1, X0, 6, true, true); + // ld1rd {z1.t}, p0/z, [x0, 28] + Z1 = createLD1R_IMM(CurDAG, DL, P1, X0, 7, true, true); + } else { + // ld1rd {z5.t}, p0/z, [x0, 8] + Z5 = createLD1R_IMM(CurDAG, DL, P1, X0, 1, false, true); + // ld1rd {z0.t}, p0/z, [x0, 48] + Z0 = createLD1R_IMM(CurDAG, DL, P1, X0, 6, false, true); + // ld1rd {z6.t}, p0/z, [x0, 24] + Z6 = createLD1R_IMM(CurDAG, DL, P1, X0, 3, false, true); + // ld1rd {z3.t}, p0/z, [x0, 32] + Z3 = createLD1R_IMM(CurDAG, DL, P1, X0, 4, false, true); + // ld1rd {z16.t}, p0/z, [x0] + Z16 = createLD1R_IMM(CurDAG, DL, P1, X0, 0, false, true); + // ld1rd {z4.t}, p0/z, [x0, 40] + 
Z4 = createLD1R_IMM(CurDAG, DL, P1, X0, 5, false, true); + // ld1rd {z1.t}, p0/z, [x0, 56] + Z1 = createLD1R_IMM(CurDAG, DL, P1, X0, 7, false, true); + } + + // fmad z0.t, p1/m, z7.t, z5.t + Z0 = createFMAD_ZPmZZ(CurDAG, DL, P1, Z0, Arg, Z5, isFloat); + // facgt p0.t, p1/z, z7.t, z16.t + SDValue P0 = createFACGT_PPzZZ(CurDAG, DL, P1, Arg, Z16, isFloat); + // fsub z5.t, z0.t, z5.t + Z5 = createFSUB_ZZZ(CurDAG, DL, Z0, Z5, isFloat); + // fmsb z6.t, p1/m, z5.t, z7.t + Z6 = createFMSB_ZPmZZ(CurDAG, DL, P1, Z6, Z5, Arg, isFloat); + // fmsb z3.t, p1/m, z5.t, z6.t + Z3 = createFMSB_ZPmZZ(CurDAG, DL, P1, Z3, Z5, Z6, isFloat); + // fmsb z4.t, p1/m, z5.t, z3.t + Z4 = createFMSB_ZPmZZ(CurDAG, DL, P1, Z4, Z5, Z3, isFloat); + + // ftsmul z3.t, z4.t, z0.t + Z3 = createFTSMUL_ZZZ(CurDAG, DL, Z4, Z0, isFloat); + // ftssel z0.t, z4.t, z0.t + Z0 = createFTSSEL_ZZZ(CurDAG, DL, Z4, Z0, isFloat); + + if (!isFloat) { + // ftmad z2.t, z2.t, z3.t, 7 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 7); + // ftmad z2.t, z2.t, z3.t, 6 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 6); + // ftmad z2.t, z2.t, z3.t, 5 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 5); + } + // ftmad z2.t, z2.t, z3.t, 4 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 4, isFloat); + // ftmad z2.t, z2.t, z3.t, 3 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 3, isFloat); + // ftmad z2.t, z2.t, z3.t, 2 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 2, isFloat); + // ftmad z2.t, z2.t, z3.t, 1 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 1, isFloat); + // ftmad z2.t, z2.t, z3.t, 0 + Z2 = createFTMAD_ZZI(CurDAG, DL, Z2, Z3, 0, isFloat); + // fmul z0.t, z2.t, z0.t + Z0 = createFMUL_ZZZ(CurDAG, DL, Z2, Z0, isFloat); + // sel z0.t, p0, z1.t, z0.t + Z0 = createSEL_ZPZZ(CurDAG, DL, P0, Z1, Z0, isFloat); + + CurDAG->ReplaceAllUsesWith(N, &Z0); + ReplaceNode(N, Z0.getNode()); +} + +void AArch64DAGToDAGISel::SelectInlineMath(unsigned IntNo, SDNode *N, EVT VT) { + if (!Subtarget->hasSVE()) { + llvm_unreachable("Unexpected intrinsic!"); + } + bool isFloat 
= VT == MVT::nxv4f32; + switch (IntNo) { + default: + llvm_unreachable("Unexpected intrinsic!"); + case Intrinsic::aarch64_cos: + SelectCosSVE(N, IntNo, isFloat); + return; + case Intrinsic::aarch64_sin: + SelectSinSVE(N, IntNo, isFloat); + return; + } +} + void AArch64DAGToDAGISel::SelectTable(SDNode *N, unsigned NumVecs, unsigned Opc, bool isExt) { SDLoc dl(N); @@ -4874,6 +5330,19 @@ switch (IntNo) { default: break; + case Intrinsic::aarch64_cos: + case Intrinsic::aarch64_sin: { + switch (VT.getSimpleVT().SimpleTy) { + default: + llvm_unreachable("Unexpected intrinsic type!"); + case MVT::nxv2f64: // for SVE version inline math + case MVT::nxv4f32: // for SVE version inline math + SelectInlineMath(IntNo, Node, VT.getSimpleVT().SimpleTy); + return; + } + return; + } + case Intrinsic::aarch64_tagp: SelectTagP(Node); return; Index: llvm/lib/Target/AArch64/AArch64InlineMathPass.cpp =================================================================== --- /dev/null +++ llvm/lib/Target/AArch64/AArch64InlineMathPass.cpp @@ -0,0 +1,343 @@ +//===-- AArch64InlineMathPass.cpp - AArch64 Inline Math Function pass --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the Inline Math Function pass which inlines the specific +// math functions. 
+//===----------------------------------------------------------------------===// + +#include "AArch64InlineMathSymbols.h" +#include "AArch64TargetMachine.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/CodeGen/TargetLowering.h" +#include "llvm/CodeGen/TargetPassConfig.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/DebugInfo/Symbolize/Symbolize.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IntrinsicsAArch64.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "aarch64-inline-math" + +namespace { + +class AArch64InlineMath : public FunctionPass { + +public: + static char ID; + + AArch64InlineMath() : FunctionPass(ID) { + initializeAArch64InlineMathPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<LoopInfoWrapperPass>(); + } + + const char *getTargetInlineMathName(Intrinsic::ID ID, const Type *T) { + + if (T->isVectorTy()) { + T = dyn_cast<VectorType>(T)->getElementType(); + } else if (T->isStructTy()) { + // for complex return value + T = dyn_cast<StructType>(T)->getElementType(0); + } + + Type::TypeID TyID = T->getTypeID(); + if (TyID != Type::DoubleTyID && TyID != Type::FloatTyID) + return ""; + + switch (ID) { + default: + llvm_unreachable("Unknown intrinsic"); + case Intrinsic::aarch64_cos: + return TyID == Type::DoubleTyID ? "cos" : "cosf"; + case Intrinsic::aarch64_sin: + return TyID == Type::DoubleTyID ? 
"sin" : "sinf"; + } + + return ""; + } + + Intrinsic::ID getInlineMathIntrinsicID(IntrinsicInst *II) const { + switch (II->getIntrinsicID()) { + default: + return Intrinsic::not_intrinsic; + case Intrinsic::cos: + return Intrinsic::aarch64_cos; + case Intrinsic::sin: + return Intrinsic::aarch64_sin; + } + } + + bool createSVESinorCosFTbl(const Intrinsic::ID IID, Module *M, Type *Ty) { + if (IID == Intrinsic::aarch64_sin && M->getNamedGlobal(SN_SIN_NXV4F32_TBL)) + return true; + + if (IID == Intrinsic::aarch64_cos && M->getNamedGlobal(SN_COS_NXV4F32_TBL)) + return true; + + Constant *SVESinorCosFTbl[] = { + ConstantInt::get(Ty, 1059256707), ConstantInt::get(Ty, 1262485504), + ConstantInt::get(Ty, 1262485505), ConstantInt::get(Ty, 1229516160), + ConstantInt::get(Ty, 1070141402), ConstantInt::get(Ty, 866263400), + ConstantInt::get(Ty, 667038912), ConstantInt::get(Ty, 2143289344), + }; + Constant *CA = ConstantArray::get(ArrayType::get(Ty, 8), SVESinorCosFTbl); + GlobalVariable *GV; + if (IID == Intrinsic::aarch64_sin) + GV = new GlobalVariable( + *M, CA->getType(), true, GlobalValue::InternalLinkage, CA, + Twine(SN_SIN_NXV4F32_TBL), nullptr, GlobalVariable::NotThreadLocal); + if (IID == Intrinsic::aarch64_cos) + GV = new GlobalVariable( + *M, CA->getType(), true, GlobalValue::InternalLinkage, CA, + Twine(SN_COS_NXV4F32_TBL), nullptr, GlobalVariable::NotThreadLocal); + return GV ? 
true : false; + } + + bool createSVESinorCosDTbl(const Intrinsic::ID IID, Module *M, Type *Ty) { + if (IID == Intrinsic::aarch64_sin && M->getNamedGlobal(SN_SIN_NXV2F64_TBL)) + return true; + + if (IID == Intrinsic::aarch64_cos && M->getNamedGlobal(SN_COS_NXV2F64_TBL)) + return true; + + Constant *SVESinorCosDTbl[] = { + ConstantInt::get(Ty, 0x43291508581d4000ULL), + ConstantInt::get(Ty, 0x4338000000000000ULL), + ConstantInt::get(Ty, 0x4338000000000001ULL), + ConstantInt::get(Ty, 0x3ff921fb50000000ULL), + ConstantInt::get(Ty, 0x3e5110b460000000ULL), + ConstantInt::get(Ty, 0x3c91a62633145c07ULL), + ConstantInt::get(Ty, 0x3fe45f306dc9c882ULL), + ConstantInt::get(Ty, 0x7ff8000000000000ULL), + }; + Constant *CA = ConstantArray::get(ArrayType::get(Ty, 8), SVESinorCosDTbl); + GlobalVariable *GV; + if (IID == Intrinsic::aarch64_sin) + GV = new GlobalVariable( + *M, CA->getType(), true, GlobalValue::InternalLinkage, CA, + Twine(SN_SIN_NXV2F64_TBL), nullptr, GlobalVariable::NotThreadLocal); + if (IID == Intrinsic::aarch64_cos) + GV = new GlobalVariable( + *M, CA->getType(), true, GlobalValue::InternalLinkage, CA, + Twine(SN_COS_NXV2F64_TBL), nullptr, GlobalVariable::NotThreadLocal); + return GV ? 
true : false; + } + + bool createSVEConstValueTable(Module *M, const Intrinsic::ID IID, + const Type *T) { + LLVMContext &Ctx = M->getContext(); + + if (T->getScalarType()->isFloatTy()) { + switch (IID) { + default: + llvm_unreachable("Unknown intrinsic"); + case Intrinsic::aarch64_sin: + case Intrinsic::aarch64_cos: + return createSVESinorCosFTbl(IID, M, Type::getInt32Ty(Ctx)); + } + } + + if (T->getScalarType()->isDoubleTy()) { + switch (IID) { + default: + llvm_unreachable("Unknown intrinsic"); + case Intrinsic::aarch64_sin: + case Intrinsic::aarch64_cos: + return createSVESinorCosDTbl(IID, M, Type::getInt64Ty(Ctx)); + } + } + llvm_unreachable("Invalid Type"); + } + + bool hasVolatile(IntrinsicInst *II) const { + for (auto &Args : dyn_cast<CallInst>(II)->args()) + if (auto *Load = dyn_cast<LoadInst>(Args)) + if (Load->isVolatile()) { + LLVM_DEBUG(dbgs() << *Args << " is volatile. \n"); + return true; + } + for (auto *User : II->users()) + if (auto *Store = dyn_cast<StoreInst>(User)) { + if (Store->isVolatile()) { + LLVM_DEBUG(dbgs() << *User << " is volatile. \n"); + return true; + } + } else if (auto *Extract = dyn_cast<ExtractValueInst>(User)) { + // %1 = call { double, double } @llvm.cexp.f64.f64(double %.real, double + // %.imag), !dbg !65 %2 = extractvalue { double, double } %1, 0, !dbg + // !65 %3 = extractvalue { double, double } %1, 1, !dbg !65 %4 = load { + // double, double }*, { double, double }** %RES.addr, align 8, !dbg !66, + // !tbaa !61 + // %.realp1 = getelementptr inbounds { double, double }, { double, + // double }* %4, i32 0, i32 0, !dbg !67 + // %.imagp2 = getelementptr inbounds { double, double }, { double, + // double }* %4, i32 0, i32 1, !dbg !67 store volatile double %2, + // double* %.realp1, align 8, !dbg !67 store volatile double %3, double* + // %.imagp2, align 8, !dbg !67 + for (auto *User2 : Extract->users()) { + if (auto *Store2 = dyn_cast<StoreInst>(User2)) { + if (Store2->isVolatile()) { + LLVM_DEBUG(dbgs() << *User << " is volatile. 
\n"); + return true; + } + } + } + } + return false; + } + + bool isInlineMathTarget(Type *T, bool HasSVE) { + Type *ScalarType = T->getScalarType(); + if (VectorType *VT = dyn_cast<VectorType>(T)) { + if (!HasSVE) + return false; + + if (!(VT->getElementCount().isScalable())) + return false; + + if (!(ScalarType->isDoubleTy() || ScalarType->isFloatTy())) + return false; + + if (ScalarType->isDoubleTy() && + cast<ScalableVectorType>(VT)->getMinNumElements() != 2) + return false; + + if (ScalarType->isFloatTy() && + cast<ScalableVectorType>(VT)->getMinNumElements() != 4) + return false; + } else { + if (!HasSVE) + return false; + + if (!(ScalarType->isDoubleTy() || ScalarType->isFloatTy())) + return false; + } + return true; + } + + bool convertToInlineMath(Function &F, LoopInfo *LI, + OptimizationRemarkEmitter &ORE, bool HasSVE) { + + bool Changed = false; + + for (auto &BB : F) { + + if (LI->getLoopDepth(&BB) == 0) + continue; + + for (auto BI = BB.rbegin(), BE = BB.rend(); BI != BE;) { + Instruction *I = &*BI++; + + auto *II = dyn_cast<IntrinsicInst>(I); + Intrinsic::ID MathInt; + if (!II || (MathInt = getInlineMathIntrinsicID(II)) == + Intrinsic::not_intrinsic) + continue; + + llvm::FastMathFlags FMF = II->getFastMathFlags(); + if (!FMF.approxFunc()) + continue; + + // double, <2 x double>, + Type *T; + switch (MathInt) { + default: + T = II->getType(); + break; + } + + if (!isInlineMathTarget(T, HasSVE)) + continue; + + if (hasVolatile(II)) + continue; + + Changed = true; + + if (!HasSVE) { + return false; + } + + if (!createSVEConstValueTable(II->getModule(), MathInt, T)) { + return false; + } + + if (!T->isVectorTy()) { + return false; + } + + Function *Fn; + Fn = Intrinsic::getDeclaration(II->getModule(), MathInt, {T}); + Value *FnName = dyn_cast<Value>(Fn); + if (!FnName) + return false; + + II->setCalledFunction(Fn); + + // Report the InlineMath conversion. + ORE.emit([&]() { + const char *InlineMathVersionStr = + HasSVE ? "sve version of " : "neon version of "; + const char *VectorizedStr = (T->isVectorTy()) ? 
"vectorized " : ""; + const char *FunctionName = getTargetInlineMathName(MathInt, T); + return OptimizationRemark("inline-math", "inline-math", II) + << InlineMathVersionStr + << ore::NV("VectorizedStr", StringRef(VectorizedStr)) + << ore::NV("FunctionName", StringRef(FunctionName)) + << " inlined into " << FnName->getName().str(); + }); + } + } + + return Changed; + } + + bool runOnFunction(Function &F) override { + LLVM_DEBUG(dbgs() << "*** " << getPassName() << ": " << F.getName() + << "\n"); + + auto *TPC = getAnalysisIfAvailable<TargetPassConfig>(); + if (!TPC) + return false; + + const AArch64Subtarget *ST = + TPC->getTM<AArch64TargetMachine>().getSubtargetImpl(F); + if (!ST) + return false; + + LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + OptimizationRemarkEmitter ORE(&F); + + bool Changed = false; + bool HasSVE = ST->hasSVE(); + + Changed |= convertToInlineMath(F, LI, ORE, HasSVE); + LLVM_DEBUG(dbgs() << "convertToInlineMath: " << Changed << "\n"); + + return Changed; + } +}; +} // end anonymous namespace. + +char AArch64InlineMath::ID = 0; +INITIALIZE_PASS(AArch64InlineMath, DEBUG_TYPE, "aarch64-inline-math", false, + false) + +FunctionPass *llvm::createAArch64InlineMathPass() { + return new AArch64InlineMath(); +} Index: llvm/lib/Target/AArch64/AArch64InlineMathSymbols.h =================================================================== --- /dev/null +++ llvm/lib/Target/AArch64/AArch64InlineMathSymbols.h @@ -0,0 +1,32 @@ +//===-- AArch64InlineMathSymbols.h - AArch64 Symbols for Inline Math --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines macros used by AArch64InlineMathPass.cpp. 
+// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_LIB_TARGET_AARCH64_AARCH64INLINEMATHSYMBOLS_H +#define LLVM_LIB_TARGET_AARCH64_AARCH64INLINEMATHSYMBOLS_H + +namespace llvm { + +//---------------------------------------------------------------------------// +// Symbol names for vectorized double precision mathematical functions +//---------------------------------------------------------------------------// +#define SN_SIN_NXV2F64_TBL ".llvm.sin.nxv2f64.tbl" +#define SN_COS_NXV2F64_TBL ".llvm.cos.nxv2f64.tbl" + +//---------------------------------------------------------------------------// +// Symbol names for vectorized single precision mathematical functions +//---------------------------------------------------------------------------// +#define SN_SIN_NXV4F32_TBL ".llvm.sin.nxv4f32.tbl" +#define SN_COS_NXV4F32_TBL ".llvm.cos.nxv4f32.tbl" + +} // namespace llvm + +#endif // LLVM_LIB_TARGET_AARCH64_AARCH64INLINEMATHSYMBOLS_H Index: llvm/lib/Target/AArch64/AArch64TargetMachine.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64TargetMachine.cpp +++ llvm/lib/Target/AArch64/AArch64TargetMachine.cpp @@ -605,6 +605,10 @@ // mechanism specified in the SME ABI. addPass(createSMEABIPass()); + // Inline mathematical functions + if (TM->getOptLevel() != CodeGenOpt::None) + addPass(createAArch64InlineMathPass()); + // Add Control Flow Guard checks. if (TM->getTargetTriple().isOSWindows()) addPass(createCFGuardCheckPass()); Index: llvm/lib/Target/AArch64/CMakeLists.txt =================================================================== --- llvm/lib/Target/AArch64/CMakeLists.txt +++ llvm/lib/Target/AArch64/CMakeLists.txt @@ -61,6 +61,7 @@ AArch64RedundantCopyElimination.cpp AArch64ISelDAGToDAG.cpp AArch64ISelLowering.cpp + AArch64InlineMathPass.cpp AArch64InstrInfo.cpp AArch64KCFI.cpp AArch64LoadStoreOptimizer.cpp