diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -10794,16 +10794,19 @@ SelectionDAG &DAG = DCI.DAG; EVT VecVT = Vec.getValueType(); - EVT EltVT = VecVT.getVectorElementType(); + EVT VecEltVT = VecVT.getVectorElementType(); + EVT ResVT = N->getValueType(0); + + unsigned VecSize = VecVT.getSizeInBits(); + unsigned VecEltSize = VecEltVT.getSizeInBits(); if ((Vec.getOpcode() == ISD::FNEG || Vec.getOpcode() == ISD::FABS) && allUsesHaveSourceMods(N)) { SDLoc SL(N); - EVT EltVT = N->getValueType(0); SDValue Idx = N->getOperand(1); - SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT, - Vec.getOperand(0), Idx); - return DAG.getNode(Vec.getOpcode(), SL, EltVT, Elt); + SDValue Elt = + DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ResVT, Vec.getOperand(0), Idx); + return DAG.getNode(Vec.getOpcode(), SL, ResVT, Elt); } // ScalarRes = EXTRACT_VECTOR_ELT ((vector-BINOP Vec1, Vec2), Idx) @@ -10813,41 +10816,51 @@ // ScalarRes = scalar-BINOP Vec1Elt, Vec2Elt if (Vec.hasOneUse() && DCI.isBeforeLegalize()) { SDLoc SL(N); - EVT EltVT = N->getValueType(0); SDValue Idx = N->getOperand(1); unsigned Opc = Vec.getOpcode(); + bool ZeroExtOnTypeMismatch = true; switch(Opc) { default: break; // TODO: Support other binary operations. + case ISD::SMIN: + case ISD::SMAX: + ZeroExtOnTypeMismatch = false; + LLVM_FALLTHROUGH; case ISD::FADD: case ISD::FSUB: case ISD::FMUL: case ISD::ADD: case ISD::UMIN: case ISD::UMAX: - case ISD::SMIN: - case ISD::SMAX: case ISD::FMAXNUM: case ISD::FMINNUM: case ISD::FMAXNUM_IEEE: case ISD::FMINNUM_IEEE: { - SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT, + EVT ExtractEltTy = ResVT.getSizeInBits() <= VecEltSize ? ResVT : VecEltVT; + SDValue Elt0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ExtractEltTy, Vec.getOperand(0), Idx); - SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT, + SDValue Elt1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ExtractEltTy, Vec.getOperand(1), Idx); DCI.AddToWorklist(Elt0.getNode()); DCI.AddToWorklist(Elt1.getNode()); - return DAG.getNode(Opc, SL, EltVT, Elt0, Elt1, Vec->getFlags()); + + if (ExtractEltTy != ResVT) { + auto CastOpc = + ZeroExtOnTypeMismatch ? ISD::ZERO_EXTEND : ISD::SIGN_EXTEND; + Elt0 = DAG.getNode(CastOpc, SL, ResVT, Elt0); + Elt1 = DAG.getNode(CastOpc, SL, ResVT, Elt1); + DCI.AddToWorklist(Elt0.getNode()); + DCI.AddToWorklist(Elt1.getNode()); + } + + return DAG.getNode(Opc, SL, ResVT, Elt0, Elt1, Vec->getFlags()); } } } - unsigned VecSize = VecVT.getSizeInBits(); - unsigned EltSize = EltVT.getSizeInBits(); - // EXTRACT_VECTOR_ELT (, var-idx) => n x select (e, const-idx) if (shouldExpandVectorDynExt(N)) { SDLoc SL(N); @@ -10855,7 +10868,7 @@ SDValue V; for (unsigned I = 0, E = VecVT.getVectorNumElements(); I < E; ++I) { SDValue IC = DAG.getVectorIdxConstant(I, SL); - SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, EltVT, Vec, IC); + SDValue Elt = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, SL, ResVT, Vec, IC); if (I == 0) V = Elt; else @@ -10871,15 +10884,11 @@ // elements. This exposes more load reduction opportunities by replacing // multiple small extract_vector_elements with a single 32-bit extract. auto *Idx = dyn_cast(N->getOperand(1)); - if (isa(Vec) && - EltSize <= 16 && - EltVT.isByteSized() && - VecSize > 32 && - VecSize % 32 == 0 && - Idx) { + if (isa(Vec) && VecEltSize <= 16 && VecEltVT.isByteSized() && + VecSize > 32 && VecSize % 32 == 0 && Idx) { EVT NewVT = getEquivalentMemType(*DAG.getContext(), VecVT); - unsigned BitIndex = Idx->getZExtValue() * EltSize; + unsigned BitIndex = Idx->getZExtValue() * VecEltSize; unsigned EltIdx = BitIndex / 32; unsigned LeftoverBitIdx = BitIndex % 32; SDLoc SL(N); @@ -10894,9 +10903,16 @@ DAG.getConstant(LeftoverBitIdx, SL, MVT::i32)); DCI.AddToWorklist(Srl.getNode()); - SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, EltVT.changeTypeToInteger(), Srl); + EVT VecEltAsIntVT = VecEltVT.changeTypeToInteger(); + SDValue Trunc = DAG.getNode(ISD::TRUNCATE, SL, VecEltAsIntVT, Srl); DCI.AddToWorklist(Trunc.getNode()); - return DAG.getNode(ISD::BITCAST, SL, EltVT, Trunc); + + if (VecEltVT == ResVT) { + return DAG.getNode(ISD::BITCAST, SL, VecEltVT, Trunc); + } + + assert(ResVT.isScalarInteger()); + return DAG.getAnyExtOrTrunc(Trunc, SL, ResVT); } return SDValue();