Index: include/llvm/CodeGen/SelectionDAGNodes.h
===================================================================
--- include/llvm/CodeGen/SelectionDAGNodes.h
+++ include/llvm/CodeGen/SelectionDAGNodes.h
@@ -1714,6 +1714,14 @@
   ConstantFPSDNode *
   getConstantFPSplatNode(BitVector *UndefElements = nullptr) const;
 
+  /// \brief If this is a constant FP splat and the splatted constant FP is an
+  /// exact power of 2, return the log base 2 integer value.  Otherwise,
+  /// return -1.
+  ///
+  /// The BitWidth specifies the necessary bit precision.
+  int32_t getConstantFPSplatPow2ToLog2Int(BitVector *UndefElements,
+                                          uint32_t BitWidth) const;
+
   bool isConstant() const;
 
   static inline bool classof(const SDNode *N) {
Index: lib/CodeGen/SelectionDAG/SelectionDAG.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -13,6 +13,7 @@
 
 #include "llvm/CodeGen/SelectionDAG.h"
 #include "SDNodeDbgValue.h"
+#include "llvm/ADT/APSInt.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallSet.h"
@@ -7190,6 +7191,26 @@
   return dyn_cast_or_null<ConstantFPSDNode>(getSplatValue(UndefElements));
 }
 
+int32_t
+BuildVectorSDNode::getConstantFPSplatPow2ToLog2Int(BitVector *UndefElements,
+                                                   uint32_t BitWidth) const {
+  if (ConstantFPSDNode *CN =
+          dyn_cast_or_null<ConstantFPSDNode>(getSplatValue(UndefElements))) {
+    bool IsExact;
+    APSInt IntVal(BitWidth);
+    // Bind by reference: the value is only read, so copying it is wasted work.
+    const APFloat &APF = CN->getValueAPF();
+    if (APF.convertToInteger(IntVal, APFloat::rmTowardZero, &IsExact) !=
+            APFloat::opOK ||
+        !IsExact)
+      return -1;
+
+    // exactLogBase2() itself yields -1 when IntVal is not a power of two.
+    return IntVal.exactLogBase2();
+  }
+  return -1;
+}
+
 bool BuildVectorSDNode::isConstant() const {
   for (const SDValue &Op : op_values()) {
     unsigned Opc = Op.getOpcode();
Index: lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.cpp
+++ lib/Target/AArch64/AArch64ISelLowering.cpp
@@ 
-478,6 +478,10 @@
   setTargetDAGCombine(ISD::SINT_TO_FP);
   setTargetDAGCombine(ISD::UINT_TO_FP);
 
+  setTargetDAGCombine(ISD::FP_TO_SINT);
+  setTargetDAGCombine(ISD::FP_TO_UINT);
+  setTargetDAGCombine(ISD::FDIV);
+
   setTargetDAGCombine(ISD::INTRINSIC_WO_CHAIN);
 
   setTargetDAGCombine(ISD::ANY_EXTEND);
@@ -7529,6 +7533,170 @@
   return SDValue();
 }
 
+/// Fold a floating-point multiply by power of two into floating-point to
+/// fixed-point conversion.
+///
+/// Matches (fp_to_[su]int (fmul X, splat(2^C))) on vectors, with 1 <= C <= 32
+/// (1 <= C <= 64 for i64 results), and rewrites it to the NEON vcvtfp2fx[su]
+/// intrinsic on X with C fractional bits, followed by a TRUNCATE when the
+/// integer element type is narrower than the float element type.
+///
+/// Returns the replacement value, or an empty SDValue if the pattern does not
+/// match.
+static SDValue performFpToIntCombine(SDNode *N, SelectionDAG &DAG,
+                                     const AArch64Subtarget *Subtarget) {
+  if (!Subtarget->hasNEON())
+    return SDValue();
+
+  // getSimpleValueType() asserts on extended value types, so reject them
+  // before querying element types below.
+  if (!N->getValueType(0).isSimple())
+    return SDValue();
+
+  SDValue Op = N->getOperand(0);
+  if (!Op.getValueType().isVector() || !Op.getValueType().isSimple() ||
+      Op.getOpcode() != ISD::FMUL)
+    return SDValue();
+
+  SDValue ConstVec = Op->getOperand(1);
+  if (!isa<BuildVectorSDNode>(ConstVec))
+    return SDValue();
+
+  MVT FloatTy = Op.getSimpleValueType().getVectorElementType();
+  uint32_t FloatBits = FloatTy.getSizeInBits();
+  if (FloatBits != 32 && FloatBits != 64)
+    return SDValue();
+
+  MVT IntTy = N->getSimpleValueType(0).getVectorElementType();
+  uint32_t IntBits = IntTy.getSizeInBits();
+  if (IntBits != 16 && IntBits != 32 && IntBits != 64)
+    return SDValue();
+
+  // Avoid conversions where iN is larger than the float (e.g., float -> i64).
+  if (IntBits > FloatBits)
+    return SDValue();
+
+  BitVector UndefElements;
+  BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec);
+  int32_t Bits = IntBits == 64 ? 64 : 32;
+  int32_t C = BV->getConstantFPSplatPow2ToLog2Int(&UndefElements, Bits + 1);
+  if (C == -1 || C == 0 || C > Bits)
+    return SDValue();
+
+  MVT ResTy;
+  unsigned NumLanes = Op.getValueType().getVectorNumElements();
+  switch (NumLanes) {
+  default:
+    return SDValue();
+  case 2:
+    ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64;
+    break;
+  case 4:
+    // Only 4 x float is handled: with 64-bit float elements, v4i32 would not
+    // match the input and the nodes built below would be malformed.
+    if (FloatBits != 32)
+      return SDValue();
+    ResTy = MVT::v4i32;
+    break;
+  }
+
+  SDLoc DL(N);
+  bool IsSigned = N->getOpcode() == ISD::FP_TO_SINT;
+  unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfp2fxs
+                                      : Intrinsic::aarch64_neon_vcvtfp2fxu;
+  SDValue FixConv =
+      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, ResTy,
+                  DAG.getConstant(IntrinsicOpcode, DL, MVT::i32),
+                  Op->getOperand(0), DAG.getConstant(C, DL, MVT::i32));
+  // We can handle smaller integers by generating an extra trunc.
+  if (IntBits < FloatBits)
+    FixConv = DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), FixConv);
+
+  return FixConv;
+}
+
+/// Fold a floating-point divide by power of two into fixed-point to
+/// floating-point conversion.
+///
+/// Matches (fdiv ([su]int_to_fp X), splat(2^C)) on vectors, with C between 1
+/// and the float element width, and rewrites it to the NEON vcvtfx[su]2fp
+/// intrinsic with C fractional bits, sign- or zero-extending X first when
+/// its elements are narrower than the float elements.
+///
+/// Returns the replacement value, or an empty SDValue if the pattern does not
+/// match.
+static SDValue performFDivCombine(SDNode *N, SelectionDAG &DAG,
+                                  const AArch64Subtarget *Subtarget) {
+  if (!Subtarget->hasNEON())
+    return SDValue();
+
+  SDValue Op = N->getOperand(0);
+  unsigned Opc = Op->getOpcode();
+  if (!Op.getValueType().isVector() ||
+      (Opc != ISD::SINT_TO_FP && Opc != ISD::UINT_TO_FP))
+    return SDValue();
+
+  // getSimpleValueType() asserts on extended value types, so reject them
+  // before querying element types below.
+  if (!Op.getValueType().isSimple() ||
+      !Op.getOperand(0).getValueType().isSimple())
+    return SDValue();
+
+  SDValue ConstVec = N->getOperand(1);
+  if (!isa<BuildVectorSDNode>(ConstVec))
+    return SDValue();
+
+  MVT IntTy = Op.getOperand(0).getSimpleValueType().getVectorElementType();
+  int32_t IntBits = IntTy.getSizeInBits();
+  if (IntBits != 16 && IntBits != 32 && IntBits != 64)
+    return SDValue();
+
+  MVT FloatTy = N->getSimpleValueType(0).getVectorElementType();
+  int32_t FloatBits = FloatTy.getSizeInBits();
+  if (FloatBits != 32 && FloatBits != 64)
+    return SDValue();
+
+  // Avoid conversions where iN is larger than the float (e.g., i64 -> float).
+  if (IntBits > FloatBits)
+    return SDValue();
+
+  BitVector UndefElements;
+  BuildVectorSDNode *BV = cast<BuildVectorSDNode>(ConstVec);
+  int32_t C = BV->getConstantFPSplatPow2ToLog2Int(&UndefElements, FloatBits + 1);
+  if (C == -1 || C == 0 || C > FloatBits)
+    return SDValue();
+
+  MVT ResTy;
+  unsigned NumLanes = Op.getValueType().getVectorNumElements();
+  switch (NumLanes) {
+  default:
+    return SDValue();
+  case 2:
+    ResTy = FloatBits == 32 ? MVT::v2i32 : MVT::v2i64;
+    break;
+  case 4:
+    // Only 4 x float is handled: with 64-bit float elements, v4i32 would not
+    // match the result type and the nodes built below would be malformed.
+    if (FloatBits != 32)
+      return SDValue();
+    ResTy = MVT::v4i32;
+    break;
+  }
+
+  SDLoc DL(N);
+  SDValue ConvInput = Op.getOperand(0);
+  bool IsSigned = Opc == ISD::SINT_TO_FP;
+  if (IntBits < FloatBits)
+    ConvInput = DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL,
+                            ResTy, ConvInput);
+
+  unsigned IntrinsicOpcode = IsSigned ? Intrinsic::aarch64_neon_vcvtfxs2fp
+                                      : Intrinsic::aarch64_neon_vcvtfxu2fp;
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, Op.getValueType(),
+                     DAG.getConstant(IntrinsicOpcode, DL, MVT::i32), ConvInput,
+                     DAG.getConstant(C, DL, MVT::i32));
+}
+
 /// An EXTR instruction is made up of two shifts, ORed together. This helper
 /// searches for and classifies those shifts.
 static bool findEXTRHalf(SDValue N, SDValue &Src, uint32_t &ShiftAmount,
@@ -9400,6 +9568,11 @@
   case ISD::SINT_TO_FP:
   case ISD::UINT_TO_FP:
     return performIntToFpCombine(N, DAG, Subtarget);
+  case ISD::FP_TO_SINT:
+  case ISD::FP_TO_UINT:
+    return performFpToIntCombine(N, DAG, Subtarget);
+  case ISD::FDIV:
+    return performFDivCombine(N, DAG, Subtarget);
   case ISD::OR:
     return performORCombine(N, DCI, Subtarget);
   case ISD::INTRINSIC_WO_CHAIN:
Index: lib/Target/ARM/ARMISelLowering.cpp
===================================================================
--- lib/Target/ARM/ARMISelLowering.cpp
+++ lib/Target/ARM/ARMISelLowering.cpp
@@ -9800,32 +9800,6 @@
   return SDValue();
 }
 
-// isConstVecPow2 - Return true if each vector element is a power of 2, all
-// elements are the same constant, C, and Log2(C) ranges from 1 to 32.
-static bool isConstVecPow2(SDValue ConstVec, bool isSigned, uint64_t &C) -{ - integerPart cN; - integerPart c0 = 0; - for (unsigned I = 0, E = ConstVec.getValueType().getVectorNumElements(); - I != E; I++) { - ConstantFPSDNode *C = dyn_cast(ConstVec.getOperand(I)); - if (!C) - return false; - - bool isExact; - APFloat APF = C->getValueAPF(); - if (APF.convertToInteger(&cN, 64, isSigned, APFloat::rmTowardZero, &isExact) - != APFloat::opOK || !isExact) - return false; - - c0 = (I == 0) ? cN : c0; - if (!isPowerOf2_64(cN) || c0 != cN || Log2_64(c0) < 1 || Log2_64(c0) > 32) - return false; - } - C = c0; - return true; -} - /// PerformVCVTCombine - VCVT (floating-point to fixed-point, Advanced SIMD) /// can replace combinations of VMUL and VCVT (floating-point to integer) /// when the VMUL has a constant operand that is a power of 2. @@ -9861,18 +9835,20 @@ return SDValue(); } - uint64_t C; - bool isSigned = N->getOpcode() == ISD::FP_TO_SINT; - if (!isConstVecPow2(ConstVec, isSigned, C)) + BitVector UndefElements; + BuildVectorSDNode *BV = cast(ConstVec); + int32_t C = BV->getConstantFPSplatPow2ToLog2Int(&UndefElements, 33); + if (C == -1 || C == 0 || C > 32) return SDValue(); SDLoc dl(N); + bool isSigned = N->getOpcode() == ISD::FP_TO_SINT; unsigned IntrinsicOpcode = isSigned ? Intrinsic::arm_neon_vcvtfp2fxs : Intrinsic::arm_neon_vcvtfp2fxu; SDValue FixConv = DAG.getNode( ISD::INTRINSIC_WO_CHAIN, dl, NumLanes == 2 ? 
MVT::v2i32 : MVT::v4i32, DAG.getConstant(IntrinsicOpcode, dl, MVT::i32), Op->getOperand(0), - DAG.getConstant(Log2_64(C), dl, MVT::i32)); + DAG.getConstant(C, dl, MVT::i32)); if (IntBits < FloatBits) FixConv = DAG.getNode(ISD::TRUNCATE, dl, N->getValueType(0), FixConv); @@ -9915,13 +9891,15 @@ return SDValue(); } - uint64_t C; - bool isSigned = OpOpcode == ISD::SINT_TO_FP; - if (!isConstVecPow2(ConstVec, isSigned, C)) + BitVector UndefElements; + BuildVectorSDNode *BV = cast(ConstVec); + int32_t C = BV->getConstantFPSplatPow2ToLog2Int(&UndefElements, 33); + if (C == -1 || C == 0 || C > 32) return SDValue(); SDLoc dl(N); SDValue ConvInput = Op.getOperand(0); + bool isSigned = OpOpcode == ISD::SINT_TO_FP; unsigned NumLanes = Op.getValueType().getVectorNumElements(); if (IntBits < FloatBits) ConvInput = DAG.getNode(isSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, @@ -9933,7 +9911,7 @@ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, dl, Op.getValueType(), DAG.getConstant(IntrinsicOpcode, dl, MVT::i32), - ConvInput, DAG.getConstant(Log2_64(C), dl, MVT::i32)); + ConvInput, DAG.getConstant(C, dl, MVT::i32)); } /// Getvshiftimm - Check if this is a valid build_vector for the immediate Index: test/CodeGen/AArch64/fcvt_combine.ll =================================================================== --- /dev/null +++ test/CodeGen/AArch64/fcvt_combine.ll @@ -0,0 +1,154 @@ +; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-neon-syntax=apple -verify-machineinstrs -o - %s | FileCheck %s + +; CHECK-LABEL: test1 +; CHECK-NOT: fmul.2s +; CHECK: fcvtzs.2s v0, v0, #4 +; CHECK: ret +define <2 x i32> @test1(<2 x float> %f) { + %mul.i = fmul <2 x float> %f, + %vcvt.i = fptosi <2 x float> %mul.i to <2 x i32> + ret <2 x i32> %vcvt.i +} + +; CHECK-LABEL: test2 +; CHECK-NOT: fmul.4s +; CHECK: fcvtzs.4s v0, v0, #3 +; CHECK: ret +define <4 x i32> @test2(<4 x float> %f) { + %mul.i = fmul <4 x float> %f, + %vcvt.i = fptosi <4 x float> %mul.i to <4 x i32> + ret <4 x i32> %vcvt.i +} + +; CHECK-LABEL: test3 
+; CHECK-NOT: fmul.2d +; CHECK: fcvtzs.2d v0, v0, #5 +; CHECK: ret +define <2 x i64> @test3(<2 x double> %d) { + %mul.i = fmul <2 x double> %d, + %vcvt.i = fptosi <2 x double> %mul.i to <2 x i64> + ret <2 x i64> %vcvt.i +} + +; Truncate double to i32 +; CHECK-LABEL: test4 +; CHECK-NOT: fmul.2d v0, v0, #4 +; CHECK: fcvtzs.2d v0, v0 +; CHECK: xtn.2s +; CHECK: ret +define <2 x i32> @test4(<2 x double> %d) { + %mul.i = fmul <2 x double> %d, + %vcvt.i = fptosi <2 x double> %mul.i to <2 x i32> + ret <2 x i32> %vcvt.i +} + +; Truncate float to i16 +; CHECK-LABEL: test5 +; CHECK-NOT: fmul.2s +; CHECK: fcvtzs.2s v0, v0, #4 +; CHECK: ret +define <2 x i16> @test5(<2 x float> %f) { + %mul.i = fmul <2 x float> %f, + %vcvt.i = fptosi <2 x float> %mul.i to <2 x i16> + ret <2 x i16> %vcvt.i +} + +; Don't convert float to i64 +; CHECK-LABEL: test6 +; CHECK: fmov.2s v1, #16.00000000 +; CHECK: fmul.2s v0, v0, v1 +; CHECK: fcvtl v0.2d, v0.2s +; CHECK: fcvtzs.2d v0, v0 +; CHECK: ret +define <2 x i64> @test6(<2 x float> %f) { + %mul.i = fmul <2 x float> %f, + %vcvt.i = fptosi <2 x float> %mul.i to <2 x i64> + ret <2 x i64> %vcvt.i +} + +; Check unsigned conversion. +; CHECK-LABEL: test7 +; CHECK-NOT: fmul.2s +; CHECK: fcvtzu.2s v0, v0, #4 +; CHECK: ret +define <2 x i32> @test7(<2 x float> %f) { + %mul.i = fmul <2 x float> %f, + %vcvt.i = fptoui <2 x float> %mul.i to <2 x i32> + ret <2 x i32> %vcvt.i +} + +; Test which should not fold due to non-power of 2. +; CHECK-LABEL: test8 +; CHECK: fmov.2s v1, #17.00000000 +; CHECK: fmul.2s v0, v0, v1 +; CHECK: fcvtzu.2s v0, v0 +; CHECK: ret +define <2 x i32> @test8(<2 x float> %f) { + %mul.i = fmul <2 x float> %f, + %vcvt.i = fptoui <2 x float> %mul.i to <2 x i32> + ret <2 x i32> %vcvt.i +} + +; Test which should not fold due to non-matching power of 2. 
+; CHECK-LABEL: test9 +; CHECK: fmul.2s v0, v0, v1 +; CHECK: fcvtzu.2s v0, v0 +; CHECK: ret +define <2 x i32> @test9(<2 x float> %f) { + %mul.i = fmul <2 x float> %f, + %vcvt.i = fptoui <2 x float> %mul.i to <2 x i32> + ret <2 x i32> %vcvt.i +} + +; Don't combine all undefs. +; CHECK-LABEL: test10 +; CHECK: fmul.2s v{{[0-9]+}}, v{{[0-9]+}}, v{{[0-9]+}} +; CHECK: fcvtzu.2s v{{[0-9]+}}, v{{[0-9]+}} +; CHECK: ret +define <2 x i32> @test10(<2 x float> %f) { + %mul.i = fmul <2 x float> %f, + %vcvt.i = fptoui <2 x float> %mul.i to <2 x i32> + ret <2 x i32> %vcvt.i +} + +; Combine if mix of undef and pow2. +; CHECK-LABEL: test11 +; CHECK: fcvtzu.2s v0, v0, #3 +; CHECK: ret +define <2 x i32> @test11(<2 x float> %f) { + %mul.i = fmul <2 x float> %f, + %vcvt.i = fptoui <2 x float> %mul.i to <2 x i32> + ret <2 x i32> %vcvt.i +} + +; Don't combine when multiplied by 0.0. +; CHECK-LABEL: test12 +; CHECK: fmul.2s v0, v0, v1 +; CHECK: fcvtzs.2s v0, v0 +; CHECK: ret +define <2 x i32> @test12(<2 x float> %f) { + %mul.i = fmul <2 x float> %f, + %vcvt.i = fptosi <2 x float> %mul.i to <2 x i32> + ret <2 x i32> %vcvt.i +} + +; Test which should not fold due to power of 2 out of range (i.e., 2^33). +; CHECK-LABEL: test13 +; CHECK: fmul.2s v0, v0, v1 +; CHECK: fcvtzs.2s v0, v0 +; CHECK: ret +define <2 x i32> @test13(<2 x float> %f) { + %mul.i = fmul <2 x float> %f, + %vcvt.i = fptosi <2 x float> %mul.i to <2 x i32> + ret <2 x i32> %vcvt.i +} + +; Test case where const is max power of 2 (i.e., 2^32). 
+; CHECK-LABEL: test14 +; CHECK: fcvtzs.2s v0, v0, #32 +; CHECK: ret +define <2 x i32> @test14(<2 x float> %f) { + %mul.i = fmul <2 x float> %f, + %vcvt.i = fptosi <2 x float> %mul.i to <2 x i32> + ret <2 x i32> %vcvt.i +} Index: test/CodeGen/AArch64/fdiv_combine.ll =================================================================== --- /dev/null +++ test/CodeGen/AArch64/fdiv_combine.ll @@ -0,0 +1,115 @@ +; RUN: llc -mtriple=aarch64-linux-gnu -aarch64-neon-syntax=apple -verify-machineinstrs -o - %s | FileCheck %s + +; Test signed conversion. +; CHECK-LABEL: @test1 +; CHECK: scvtf.2s v0, v0, #4 +; CHECK: ret +define <2 x float> @test1(<2 x i32> %in) { +entry: + %vcvt.i = sitofp <2 x i32> %in to <2 x float> + %div.i = fdiv <2 x float> %vcvt.i, + ret <2 x float> %div.i +} + +; Test unsigned conversion. +; CHECK-LABEL: @test2 +; CHECK: ucvtf.2s v0, v0, #3 +; CHECK: ret +define <2 x float> @test2(<2 x i32> %in) { +entry: + %vcvt.i = uitofp <2 x i32> %in to <2 x float> + %div.i = fdiv <2 x float> %vcvt.i, + ret <2 x float> %div.i +} + +; Test which should not fold due to non-power of 2. +; CHECK-LABEL: @test3 +; CHECK: scvtf.2s v0, v0 +; CHECK: fmov.2s v1, #9.00000000 +; CHECK: fdiv.2s v0, v0, v1 +; CHECK: ret +define <2 x float> @test3(<2 x i32> %in) { +entry: + %vcvt.i = sitofp <2 x i32> %in to <2 x float> + %div.i = fdiv <2 x float> %vcvt.i, + ret <2 x float> %div.i +} + +; Test which should not fold due to power of 2 out of range. +; CHECK-LABEL: @test4 +; CHECK: scvtf.2s v0, v0 +; CHECK: movi.2s v1, #0x50, lsl #24 +; CHECK: fdiv.2s v0, v0, v1 +; CHECK: ret +define <2 x float> @test4(<2 x i32> %in) { +entry: + %vcvt.i = sitofp <2 x i32> %in to <2 x float> + %div.i = fdiv <2 x float> %vcvt.i, + ret <2 x float> %div.i +} + +; Test case where const is max power of 2 (i.e., 2^32). 
+; CHECK-LABEL: @test5 +; CHECK: scvtf.2s v0, v0, #32 +; CHECK: ret +define <2 x float> @test5(<2 x i32> %in) { +entry: + %vcvt.i = sitofp <2 x i32> %in to <2 x float> + %div.i = fdiv <2 x float> %vcvt.i, + ret <2 x float> %div.i +} + +; Test quadword. +; CHECK-LABEL: @test6 +; CHECK: scvtf.4s v0, v0, #2 +; CHECK: ret +define <4 x float> @test6(<4 x i32> %in) { +entry: + %vcvt.i = sitofp <4 x i32> %in to <4 x float> + %div.i = fdiv <4 x float> %vcvt.i, + ret <4 x float> %div.i +} + +; Test unsigned i16 to float +; CHECK-LABEL: @test7 +; CHECK: ushll.4s v0, v0, #0 +; CHECK: ucvtf.4s v0, v0, #1 +; CHECK: ret +define <4 x float> @test7(<4 x i16> %in) { + %conv = uitofp <4 x i16> %in to <4 x float> + %shift = fdiv <4 x float> %conv, + ret <4 x float> %shift +} + +; Test signed i16 to float +; CHECK-LABEL: @test8 +; CHECK: sshll.4s v0, v0, #0 +; CHECK: scvtf.4s v0, v0, #2 +; CHECK: ret +define <4 x float> @test8(<4 x i16> %in) { + %conv = sitofp <4 x i16> %in to <4 x float> + %shift = fdiv <4 x float> %conv, + ret <4 x float> %shift +} + +; Can't convert i64 to float. +; CHECK-LABEL: @test9 +; CHECK: ucvtf.2d v0, v0 +; CHECK: fcvtn v0.2s, v0.2d +; CHECK: movi.2s v1, #0x40, lsl #24 +; CHECK: fdiv.2s v0, v0, v1 +; CHECK: ret +define <2 x float> @test9(<2 x i64> %in) { + %conv = uitofp <2 x i64> %in to <2 x float> + %shift = fdiv <2 x float> %conv, + ret <2 x float> %shift +} + +; CHECK-LABEL: @test10 +; CHECK: ucvtf.2d v0, v0, #1 +; CHECK: ret +define <2 x double> @test10(<2 x i64> %in) { + %conv = uitofp <2 x i64> %in to <2 x double> + %shift = fdiv <2 x double> %conv, + ret <2 x double> %shift +}