diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -5768,14 +5768,11 @@ EVT EltVT = VecVT.getVectorElementType(); unsigned VecSize = VecVT.getSizeInBits(); unsigned EltSize = EltVT.getSizeInBits(); + SDLoc SL(Op); - - assert(VecSize <= 64); - + // Specially handle the case of v4i16 with static indexing. unsigned NumElts = VecVT.getVectorNumElements(); - SDLoc SL(Op); auto KIdx = dyn_cast(Idx); - if (NumElts == 4 && EltSize == 16 && KIdx) { SDValue BCVec = DAG.getNode(ISD::BITCAST, SL, MVT::v2i32, Vec); @@ -5803,35 +5800,46 @@ return DAG.getNode(ISD::BITCAST, SL, VecVT, Concat); } + // Static indexing does not lower to stack access, and hence there is no need + // for special custom lowering to avoid stack access. if (isa(Idx)) return SDValue(); - MVT IntVT = MVT::getIntegerVT(VecSize); - - // Avoid stack access for dynamic indexing. + // Avoid stack access for dynamic indexing by custom lowering to // v_bfi_b32 (v_bfm_b32 16, (shl idx, 16)), val, vec + // + // TODO: However, we can only handle vector sizes of up to 64 bits for now. + assert(VecSize <= 64); - // Create a congruent vector with the target value in each element so that - // the required element can be masked and ORed into the target vector. - SDValue ExtVal = DAG.getNode(ISD::BITCAST, SL, IntVT, - DAG.getSplatBuildVector(VecVT, SL, InsVal)); + // 1. Create a congruent vector with the target value in each element. + // 2. Mask off all other indices except the required index within (1). + // 3. Mask off the required index within the target vector. + // 4. Get (2) and (3) ORed into the target vector. + MVT IntVT = MVT::getIntegerVT(VecSize); + // Convert vector index to bit-index and get the required bit mask. assert(isPowerOf2_32(EltSize)); SDValue ScaleFactor = DAG.getConstant(Log2_32(EltSize), SL, MVT::i32); - - // Convert vector index to bit-index.
SDValue ScaledIdx = DAG.getNode(ISD::SHL, SL, MVT::i32, Idx, ScaleFactor); - - SDValue BCVec = DAG.getNode(ISD::BITCAST, SL, IntVT, Vec); SDValue BFM = DAG.getNode(ISD::SHL, SL, IntVT, DAG.getConstant(0xffff, SL, IntVT), ScaledIdx); + // Perform (1) above. + SDValue ExtVal = DAG.getNode(ISD::BITCAST, SL, IntVT, + DAG.getSplatBuildVector(VecVT, SL, InsVal)); + + // Perform (2) above. SDValue LHS = DAG.getNode(ISD::AND, SL, IntVT, BFM, ExtVal); + + // Perform (3) above. + SDValue BCVec = DAG.getNode(ISD::BITCAST, SL, IntVT, Vec); SDValue RHS = DAG.getNode(ISD::AND, SL, IntVT, DAG.getNOT(SL, BFM, IntVT), BCVec); + // Perform (4) above. SDValue BFI = DAG.getNode(ISD::OR, SL, IntVT, LHS, RHS); + return DAG.getNode(ISD::BITCAST, SL, VecVT, BFI); }