diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -912,6 +912,8 @@
   setTargetDAGCombine(ISD::MUL);
 
+  setTargetDAGCombine(ISD::FP_EXTEND);
+
   setTargetDAGCombine(ISD::SELECT);
   setTargetDAGCombine(ISD::VSELECT);
@@ -15237,6 +15239,100 @@
   return SDValue();
 }
 
+static SDValue performFpExtendCombine(SDNode *N, SelectionDAG &DAG,
+                                      const AArch64Subtarget *Subtarget) {
+  SDLoc DL(N);
+  SDValue Op = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+
+  if (!VT.isFixedLengthVector())
+    return SDValue();
+
+  if (DAG.getTargetLoweringInfo().isTypeLegal(VT) ||
+      !Subtarget->useSVEForFixedLengthVectors())
+    return SDValue();
+
+  // In cases where the result of the fp_extend is not legal, it will be
+  // expanded into multiple extract_subvectors which cannot be lowered without
+  // going through memory.
+  //
+  // If we push an extend into the load feeding the fp_extend, we can force the
+  // load to be expanded into the same number of parts as the fp_extend,
+  // avoiding the need for extract_subvectors completely.
+  //
+  // As part of the lowering of FP_EXTEND for fixed-length types, uunpklo nodes
+  // will be introduced, which will then combine with the truncate introduced
+  // after the load.
+  if (ISD::isNormalLoad(Op.getNode())) {
+    LoadSDNode *LD = cast<LoadSDNode>(Op.getNode());
+
+    // Check if there are other uses. If so, do not combine as it would
+    // introduce an extra load.
+    for (SDNode::use_iterator UI = LD->use_begin(), UE = LD->use_end();
+         UI != UE; ++UI) {
+      if (UI.getUse().getResNo() == 1) // Ignore uses of the chain result.
+        continue;
+      if (*UI != N)
+        return SDValue();
+    }
+
+    SDValue NewLoad = DAG.getExtLoad(
+        ISD::ZEXTLOAD, DL, VT.changeTypeToInteger(), LD->getChain(),
+        LD->getBasePtr(), LD->getMemoryVT().changeTypeToInteger(),
+        LD->getMemOperand());
+
+    DAG.ReplaceAllUsesOfValueWith(SDValue(LD, 1), NewLoad.getValue(1));
+
+    SDValue Trunc = DAG.getNode(
+        ISD::TRUNCATE, DL, Op->getValueType(0).changeTypeToInteger(), NewLoad);
+    SDValue Bitcast =
+        DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), Trunc);
+
+    return DAG.getNode(ISD::FP_EXTEND, DL, VT, Bitcast);
+  }
+
+  return SDValue();
+}
+
+static SDValue performUunpkloCombine(SDNode *N, SelectionDAG &DAG) {
+  SDLoc DL(N);
+  SDValue Op = N->getOperand(0);
+  EVT VT = N->getValueType(0);
+
+  // uunpklo(uzp1(x, x)) where x = bitcast(zextload) -> x
+  if (Op->getOpcode() == AArch64ISD::UZP1 &&
+      Op->getOperand(0) == Op->getOperand(1)) {
+    EVT HalfVT = Op.getValueType();
+
+    // Ensure the unzip input is the same size as the unpack output.
+    if (Op->getOperand(0)->getOpcode() != ISD::BITCAST ||
+        Op->getValueType(0) == VT)
+      return SDValue();
+
+    SDValue Bitcast = Op->getOperand(0);
+
+    // Look through bitcasts and unzips.
+    SDValue Input = Bitcast->getOperand(0);
+    while (Input->getOpcode() == ISD::BITCAST ||
+           (Input->getOpcode() == AArch64ISD::UZP1 &&
+            Input->getOperand(0) == Input->getOperand(1)))
+      Input = Input->getOperand(0);
+
+    // The input should come from an extending load.
+    if (!isa<LoadSDNode>(Input) ||
+        cast<LoadSDNode>(Input)->getExtensionType() != ISD::ZEXTLOAD)
+      return SDValue();
+
+    // Ensure that we don't care about the top half of the input.
+    EVT MemVT = cast<LoadSDNode>(Input)->getMemoryVT();
+    if (isPackedVectorType(MemVT, DAG) &&
+        MemVT.getVectorElementType().getScalarSizeInBits() <=
+            HalfVT.getScalarSizeInBits())
+      return Bitcast->getOperand(0);
+  }
+
+  return SDValue();
+}
+
 static SDValue performGLD1Combine(SDNode *N, SelectionDAG &DAG) {
   unsigned Opc = N->getOpcode();
@@ -16882,6 +16978,8 @@
     return performUzpCombine(N, DAG);
   case AArch64ISD::SETCC_MERGE_ZERO:
     return performSetccMergeZeroCombine(N, DAG);
+  case ISD::FP_EXTEND:
+    return performFpExtendCombine(N, DAG, Subtarget);
   case AArch64ISD::GLD1_MERGE_ZERO:
   case AArch64ISD::GLD1_SCALED_MERGE_ZERO:
   case AArch64ISD::GLD1_UXTW_MERGE_ZERO:
@@ -16900,6 +16998,8 @@
   case AArch64ISD::VASHR:
   case AArch64ISD::VLSHR:
     return performVectorShiftCombine(N, *this, DCI);
+  case AArch64ISD::UUNPKLO:
+    return performUunpkloCombine(N, DAG);
   case ISD::INSERT_VECTOR_ELT:
     return performInsertVectorEltCombine(N, DCI);
   case ISD::EXTRACT_VECTOR_ELT:
diff --git a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-extend-trunc.ll b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-extend-trunc.ll
--- a/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-extend-trunc.ll
+++ b/llvm/test/CodeGen/AArch64/sve-fixed-length-fp-extend-trunc.ll
@@ -66,19 +66,15 @@
 ; VBITS_GE_512-NEXT: st1w { [[RES]].s }, [[PG1]], [x1]
 ; VBITS_GE_512-NEXT: ret
 
-; Ensure sensible type legalisation - fixed type extract_subvector codegen is poor currently.
-; VBITS_EQ_256-DAG: ptrue [[PG1:p[0-9]+]].h, vl16
-; VBITS_EQ_256-DAG: ld1h { [[VEC:z[0-9]+]].h }, [[PG1]]/z, [x0]
-; VBITS_EQ_256-DAG: st1h { [[VEC:z[0-9]+]].h }, [[PG1]], [x8]
-; VBITS_EQ_256-DAG: ldp q[[LO:[0-9]+]], q[[HI:[0-9]+]], [sp]
-; VBITS_EQ_256-DAG: ptrue [[PG2:p[0-9]+]].s, vl8
+; Ensure sensible type legalisation
 ; VBITS_EQ_256-DAG: mov x[[NUMELTS:[0-9]+]], #8
-; VBITS_EQ_256-DAG: uunpklo [[UPK_LO:z[0-9]+]].s, z[[LO]].h
-; VBITS_EQ_256-DAG: uunpklo [[UPK_HI:z[0-9]+]].s, z[[HI]].h
-; VBITS_EQ_256-DAG: fcvt [[RES_LO:z[0-9]+]].s, [[PG2]]/m, [[UPK_LO]].h
-; VBITS_EQ_256-DAG: fcvt [[RES_HI:z[0-9]+]].s, [[PG2]]/m, [[UPK_HI]].h
-; VBITS_EQ_256-DAG: st1w { [[RES_LO]].s }, [[PG2]], [x1]
-; VBITS_EQ_256-DAG: st1w { [[RES_HI]].s }, [[PG2]], [x1, x[[NUMELTS]], lsl #2]
+; VBITS_EQ_256-DAG: ptrue [[PG:p[0-9]+]].s, vl8
+; VBITS_EQ_256-DAG: ld1h { [[VEC_LO:z[0-9]+]].s }, [[PG]]/z, [x0]
+; VBITS_EQ_256-DAG: ld1h { [[VEC_HI:z[0-9]+]].s }, [[PG]]/z, [x0, x[[NUMELTS]], lsl #1]
+; VBITS_EQ_256-DAG: fcvt [[RES_LO:z[0-9]+]].s, [[PG]]/m, [[VEC_LO]].h
+; VBITS_EQ_256-DAG: fcvt [[RES_HI:z[0-9]+]].s, [[PG]]/m, [[VEC_HI]].h
+; VBITS_EQ_256-DAG: st1w { [[RES_LO]].s }, [[PG]], [x1]
+; VBITS_EQ_256-DAG: st1w { [[RES_HI]].s }, [[PG]], [x1, x[[NUMELTS]], lsl #2]
   %op1 = load <16 x half>, <16 x half>* %a
   %res = fpext <16 x half> %op1 to <16 x float>
   store <16 x float> %res, <16 x float>* %b
@@ -166,16 +162,12 @@
 ; VBITS_GE_512-NEXT: ret
 
 ; Ensure sensible type legalisation.
-; VBITS_EQ_256-DAG: ldr q[[OP:[0-9]+]], [x0]
-; VBITS_EQ_256-DAG: ptrue [[PG:p[0-9]+]].d, vl4
 ; VBITS_EQ_256-DAG: mov x[[NUMELTS:[0-9]+]], #4
-; VBITS_EQ_256-DAG: ext v[[HI:[0-9]+]].16b, v[[OP]].16b, v[[OP]].16b, #8
-; VBITS_EQ_256-DAG: uunpklo [[UPK1_LO:z[0-9]+]].s, z[[OP]].h
-; VBITS_EQ_256-DAG: uunpklo [[UPK1_HI:z[0-9]+]].s, z[[HI]].h
-; VBITS_EQ_256-DAG: uunpklo [[UPK2_LO:z[0-9]+]].d, [[UPK1_LO]].s
-; VBITS_EQ_256-DAG: uunpklo [[UPK2_HI:z[0-9]+]].d, [[UPK2_HI]].s
-; VBITS_EQ_256-DAG: fcvt [[RES_LO:z[0-9]+]].d, [[PG]]/m, [[UPK2_LO]].h
-; VBITS_EQ_256-DAG: fcvt [[RES_HI:z[0-9]+]].d, [[PG]]/m, [[UPK2_HI]].h
+; VBITS_EQ_256-DAG: ptrue [[PG:p[0-9]+]].d, vl4
+; VBITS_EQ_256-DAG: ld1h { [[VEC_HI:z[0-9]+]].d }, [[PG]]/z, [x0, x[[NUMELTS]], lsl #1]
+; VBITS_EQ_256-DAG: ld1h { [[VEC_LO:z[0-9]+]].d }, [[PG]]/z, [x0]
+; VBITS_EQ_256-DAG: fcvt [[RES_LO:z[0-9]+]].d, [[PG]]/m, [[VEC_LO]].h
+; VBITS_EQ_256-DAG: fcvt [[RES_HI:z[0-9]+]].d, [[PG]]/m, [[VEC_HI]].h
 ; VBITS_EQ_256-DAG: st1d { [[RES_LO]].d }, [[PG]], [x1]
 ; VBITS_EQ_256-DAG: st1d { [[RES_HI]].d }, [[PG]], [x1, x[[NUMELTS]], lsl #3]
   %op1 = load <8 x half>, <8 x half>* %a
   %res = fpext <8 x half> %op1 to <8 x double>
   store <8 x double> %res, <8 x double>* %b
@@ -262,19 +254,15 @@
 ; VBITS_GE_512-NEXT: st1d { [[RES]].d }, [[PG1]], [x1]
 ; VBITS_GE_512-NEXT: ret
 
-; Ensure sensible type legalisation - fixed type extract_subvector codegen is poor currently.
-; VBITS_EQ_256-DAG: ptrue [[PG1:p[0-9]+]].s, vl8
-; VBITS_EQ_256-DAG: ld1w { [[VEC:z[0-9]+]].s }, [[PG1]]/z, [x0]
-; VBITS_EQ_256-DAG: st1w { [[VEC:z[0-9]+]].s }, [[PG1]], [x8]
-; VBITS_EQ_256-DAG: ldp q[[LO:[0-9]+]], q[[HI:[0-9]+]], [sp]
-; VBITS_EQ_256-DAG: ptrue [[PG2:p[0-9]+]].d, vl4
+; Ensure sensible type legalisation
 ; VBITS_EQ_256-DAG: mov x[[NUMELTS:[0-9]+]], #4
-; VBITS_EQ_256-DAG: uunpklo [[UPK_LO:z[0-9]+]].d, z[[LO]].s
-; VBITS_EQ_256-DAG: uunpklo [[UPK_HI:z[0-9]+]].d, z[[HI]].s
-; VBITS_EQ_256-DAG: fcvt [[RES_LO:z[0-9]+]].d, [[PG2]]/m, [[UPK_LO]].s
-; VBITS_EQ_256-DAG: fcvt [[RES_HI:z[0-9]+]].d, [[PG2]]/m, [[UPK_HI]].s
-; VBITS_EQ_256-DAG: st1d { [[RES_LO]].d }, [[PG2]], [x1]
-; VBITS_EQ_256-DAG: st1d { [[RES_HI]].d }, [[PG2]], [x1, x[[NUMELTS]], lsl #3]
+; VBITS_EQ_256-DAG: ptrue [[PG:p[0-9]+]].d, vl4
+; VBITS_EQ_256-DAG: ld1w { [[VEC_HI:z[0-9]+]].d }, [[PG]]/z, [x0, x[[NUMELTS]], lsl #2]
+; VBITS_EQ_256-DAG: ld1w { [[VEC_LO:z[0-9]+]].d }, [[PG]]/z, [x0]
+; VBITS_EQ_256-DAG: fcvt [[RES_LO:z[0-9]+]].d, [[PG]]/m, [[VEC_LO]].s
+; VBITS_EQ_256-DAG: fcvt [[RES_HI:z[0-9]+]].d, [[PG]]/m, [[VEC_HI]].s
+; VBITS_EQ_256-DAG: st1d { [[RES_LO]].d }, [[PG]], [x1]
+; VBITS_EQ_256-DAG: st1d { [[RES_HI]].d }, [[PG]], [x1, x[[NUMELTS]], lsl #3]
   %op1 = load <8 x float>, <8 x float>* %a
   %res = fpext <8 x float> %op1 to <8 x double>
   store <8 x double> %res, <8 x double>* %b
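For reference, the shape of the improvement the new VBITS_EQ_256 checks encode, as a sketch for a 256-bit SVE vector length; register numbers are illustrative and the instruction sequence is inferred from the CHECK-DAG lines above rather than taken from verified llc output. For the first test's IR

    %op1 = load <16 x half>, <16 x half>* %a
    %res = fpext <16 x half> %op1 to <16 x float>
    store <16 x float> %res, <16 x float>* %b

the previous legalisation stored the <16 x half> to the stack and reloaded it as two q-registers before unpacking and converting, whereas with the fp_extend and uunpklo combines each half is fetched directly with an extending load and converted in register, roughly:

    mov   x8, #8
    ptrue p0.s, vl8
    ld1h  { z0.s }, p0/z, [x0]
    ld1h  { z1.s }, p0/z, [x0, x8, lsl #1]
    fcvt  z0.s, p0/m, z0.h
    fcvt  z1.s, p0/m, z1.h
    st1w  { z0.s }, p0, [x1]
    st1w  { z1.s }, p0, [x1, x8, lsl #2]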