diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -905,6 +905,7 @@
   setTargetDAGCombine(ISD::INTRINSIC_VOID);
   setTargetDAGCombine(ISD::INTRINSIC_W_CHAIN);
+  setTargetDAGCombine(ISD::INSERT_SUBVECTOR);
   setTargetDAGCombine(ISD::INSERT_VECTOR_ELT);
   setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
   setTargetDAGCombine(ISD::VECREDUCE_ADD);
@@ -14676,8 +14677,30 @@
   if (VT.isScalableVector())
     return SDValue();
 
-  unsigned LoadIdx = IsLaneOp ? 1 : 0;
-  SDNode *LD = N->getOperand(LoadIdx).getNode();
+  SDNode *LD = nullptr;
+  SDNode *LDUser = N;
+  // If N is an INSERT_SUBVECTOR, its index operand must be multiplied
+  // to get the index for LD1.
+  unsigned LaneScale = 1;
+
+  if (N->getOpcode() == ISD::INSERT_SUBVECTOR) {
+    // Handle (insert_subvector V (scalar_to_vector (load)) Idx)
+    SDValue ScalarToVec = N->getOperand(1);
+    if (ScalarToVec->getOpcode() != ISD::SCALAR_TO_VECTOR)
+      return SDValue();
+    LD = ScalarToVec->getOperand(0).getNode();
+    LDUser = ScalarToVec.getNode();
+
+    // Adjust the lane index so that (insert_subvector (... Addr) Idx)
+    // matches (LD1LANE Addr NewIdx).
+    EVT SubvecVT = ScalarToVec->getValueType(0);
+    LaneScale *= VT.getVectorNumElements() / SubvecVT.getVectorNumElements();
+    assert(LaneScale && "Invalid insert_subvector/scalar_to_vector sequence");
+  } else {
+    unsigned LoadIdx = IsLaneOp ? 1 : 0;
+    LD = N->getOperand(LoadIdx).getNode();
+  }
+
   // If it is not LOAD, can not do such combine.
   if (LD->getOpcode() != ISD::LOAD)
     return SDValue();
@@ -14687,15 +14710,34 @@
   if (IsLaneOp) {
     Lane = N->getOperand(2);
     auto *LaneC = dyn_cast<ConstantSDNode>(Lane);
-    if (!LaneC || LaneC->getZExtValue() >= VT.getVectorNumElements())
+    if (!LaneC)
+      return SDValue();
+
+    uint64_t LaneVal = LaneC->getZExtValue();
+    if (LaneVal >= VT.getVectorNumElements())
       return SDValue();
+
+    if (LaneScale != 1) {
+      Lane = DAG.getVectorIdxConstant(LaneVal * LaneScale, SDLoc(LaneC));
+    }
   }
 
   LoadSDNode *LoadSDN = cast<LoadSDNode>(LD);
   EVT MemVT = LoadSDN->getMemoryVT();
+  EVT ResultVT = VT;
   // Check if memory operand is the same type as the vector element.
-  if (MemVT != VT.getVectorElementType())
-    return SDValue();
+  if (MemVT != VT.getVectorElementType()) {
+    if (LoadSDN->getExtensionType() != ISD::EXTLOAD)
+      return SDValue();
+
+    // For EXTLOAD we need to adjust load type and lane index to match
+    // the memory type.
+    LaneScale *= VT.getScalarSizeInBits() / MemVT.getScalarSizeInBits();
+    assert(LaneScale && "Invalid load anyext");
+    ResultVT =
+        EVT::getVectorVT(*DAG.getContext(), MemVT,
+                         VT.getFixedSizeInBits() / MemVT.getScalarSizeInBits());
+  }
 
   // Check if there are other uses. If so, do not combine as it will introduce
   // an extra load.
@@ -14703,7 +14745,7 @@
        ++UI) {
     if (UI.getUse().getResNo() == 1) // Ignore uses of the chain result.
       continue;
-    if (*UI != N)
+    if (*UI != LDUser)
       return SDValue();
   }
 
@@ -14721,7 +14763,7 @@
     SDValue Inc = User->getOperand(User->getOperand(0) == Addr ? 1 : 0);
     if (ConstantSDNode *CInc = dyn_cast<ConstantSDNode>(Inc.getNode())) {
       uint32_t IncVal = CInc->getZExtValue();
-      unsigned NumBytes = VT.getScalarSizeInBits() / 8;
+      unsigned NumBytes = MemVT.getScalarSizeInBits() / 8;
       if (IncVal != NumBytes)
         continue;
       Inc = DAG.getRegister(AArch64::XZR, MVT::i64);
@@ -14748,20 +14790,24 @@
     Ops.push_back(Addr);
     Ops.push_back(Inc);
 
-    EVT Tys[3] = { VT, MVT::i64, MVT::Other };
+    EVT Tys[3] = {ResultVT, MVT::i64, MVT::Other};
     SDVTList SDTys = DAG.getVTList(Tys);
     unsigned NewOp = IsLaneOp ? AArch64ISD::LD1LANEpost : AArch64ISD::LD1DUPpost;
     SDValue UpdN = DAG.getMemIntrinsicNode(NewOp, SDLoc(N), SDTys, Ops,
                                            MemVT,
                                            LoadSDN->getMemOperand());
 
+    SDValue Result = (ResultVT == VT)
+                         ? SDValue(UpdN.getNode(), 0)
+                         : DAG.getNode(ISD::BITCAST, SDLoc(N), VT, UpdN);
+
     // Update the uses.
     SDValue NewResults[] = {
         SDValue(LD, 0),            // The result of load
         SDValue(UpdN.getNode(), 2) // Chain
     };
     DCI.CombineTo(LD, NewResults);
-    DCI.CombineTo(N, SDValue(UpdN.getNode(), 0));     // Dup/Inserted Result
+    DCI.CombineTo(N, Result);                         // Dup/Inserted Result
     DCI.CombineTo(User, SDValue(UpdN.getNode(), 1));  // Write back register
 
     break;
@@ -16143,6 +16189,8 @@
     return performVectorShiftCombine(N, *this, DCI);
   case ISD::INSERT_VECTOR_ELT:
     return performInsertVectorEltCombine(N, DCI);
+  case ISD::INSERT_SUBVECTOR:
+    return performPostLD1Combine(N, DCI, true);
   case ISD::EXTRACT_VECTOR_ELT:
     return performExtractVectorEltCombine(N, DAG);
   case ISD::VECREDUCE_ADD:
diff --git a/llvm/test/CodeGen/AArch64/aarch64-load-ext.ll b/llvm/test/CodeGen/AArch64/aarch64-load-ext.ll
--- a/llvm/test/CodeGen/AArch64/aarch64-load-ext.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-load-ext.ll
@@ -15,9 +15,8 @@
 define <2 x i16> @test1(<2 x i16>* %v2i16_ptr) {
 ; CHECK-LABEL: test1:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1 { v0.h }[0], [x0]
-; CHECK-NEXT:    add x8, x0, #2 // =2
-; CHECK-NEXT:    ld1 { v0.h }[2], [x8]
+; CHECK-NEXT:    ld1 { v0.h }[0], [x0], #2
+; CHECK-NEXT:    ld1 { v0.h }[2], [x0]
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
   %v2i16 = load <2 x i16>, <2 x i16>* %v2i16_ptr
@@ -27,9 +26,9 @@
 define <2 x i16> @test2(i16* %i16_ptr, i64 %inc) {
 ; CHECK-LABEL: test2:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1 { v0.h }[0], [x0]
-; CHECK-NEXT:    add x8, x0, x1, lsl #1
-; CHECK-NEXT:    ld1 { v0.h }[2], [x8]
+; CHECK-NEXT:    lsl x8, x1, #1
+; CHECK-NEXT:    ld1 { v0.h }[0], [x0], x8
+; CHECK-NEXT:    ld1 { v0.h }[2], [x0]
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
   %i_0 = load i16, i16* %i16_ptr
@@ -43,9 +42,8 @@
 define <2 x i8> @test3(<2 x i8>* %v2i8_ptr) {
 ; CHECK-LABEL: test3:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1 { v0.b }[0], [x0]
-; CHECK-NEXT:    add x8, x0, #1 // =1
-; CHECK-NEXT:    ld1 { v0.b }[4], [x8]
+; CHECK-NEXT:    ld1 { v0.b }[0], [x0], #1
+; CHECK-NEXT:    ld1 { v0.b }[4], [x0]
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 killed $q0
 ; CHECK-NEXT:    ret
   %v2i8 = load <2 x i8>, <2 x i8>* %v2i8_ptr
@@ -55,8 +53,8 @@
 define <4 x i8> @test4(<4 x i8>* %v4i8_ptr) {
 ; CHECK-LABEL: test4:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1 { v0.b }[0], [x0]
-; CHECK-NEXT:    add x8, x0, #1 // =1
+; CHECK-NEXT:    mov x8, x0
+; CHECK-NEXT:    ld1 { v0.b }[0], [x8], #1
 ; CHECK-NEXT:    ld1 { v0.b }[2], [x8]
 ; CHECK-NEXT:    add x8, x0, #2 // =2
 ; CHECK-NEXT:    ld1 { v0.b }[4], [x8]
diff --git a/llvm/test/CodeGen/AArch64/sadd_sat_vec.ll b/llvm/test/CodeGen/AArch64/sadd_sat_vec.ll
--- a/llvm/test/CodeGen/AArch64/sadd_sat_vec.ll
+++ b/llvm/test/CodeGen/AArch64/sadd_sat_vec.ll
@@ -145,15 +145,13 @@
 define void @v2i8(<2 x i8>* %px, <2 x i8>* %py, <2 x i8>* %pz) nounwind {
 ; CHECK-LABEL: v2i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1 { v0.b }[0], [x1]
-; CHECK-NEXT:    ld1 { v1.b }[0], [x0]
-; CHECK-NEXT:    add x8, x0, #1 // =1
-; CHECK-NEXT:    add x9, x1, #1 // =1
-; CHECK-NEXT:    ld1 { v0.b }[4], [x9]
-; CHECK-NEXT:    ld1 { v1.b }[4], [x8]
+; CHECK-NEXT:    ld1 { v0.b }[0], [x0], #1
+; CHECK-NEXT:    ld1 { v0.b }[4], [x0]
+; CHECK-NEXT:    ld1 { v1.b }[0], [x1], #1
 ; CHECK-NEXT:    shl v0.2s, v0.2s, #24
+; CHECK-NEXT:    ld1 { v1.b }[4], [x1]
 ; CHECK-NEXT:    shl v1.2s, v1.2s, #24
-; CHECK-NEXT:    sqadd v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    sqadd v0.2s, v0.2s, v1.2s
 ; CHECK-NEXT:    ushr v0.2s, v0.2s, #24
 ; CHECK-NEXT:    mov w8, v0.s[1]
 ; CHECK-NEXT:    fmov w9, s0
@@ -185,15 +183,13 @@
 define void @v2i16(<2 x i16>* %px, <2 x i16>* %py, <2 x i16>* %pz) nounwind {
 ; CHECK-LABEL: v2i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1 { v0.h }[0], [x1]
-; CHECK-NEXT:    ld1 { v1.h }[0], [x0]
-; CHECK-NEXT:    add x8, x0, #2 // =2
-; CHECK-NEXT:    add x9, x1, #2 // =2
-; CHECK-NEXT:    ld1 { v0.h }[2], [x9]
-; CHECK-NEXT:    ld1 { v1.h }[2], [x8]
+; CHECK-NEXT:    ld1 { v0.h }[0], [x0], #2
+; CHECK-NEXT:    ld1 { v0.h }[2], [x0]
+; CHECK-NEXT:    ld1 { v1.h }[0], [x1], #2
 ; CHECK-NEXT:    shl v0.2s, v0.2s, #16
+; CHECK-NEXT:    ld1 { v1.h }[2], [x1]
 ; CHECK-NEXT:    shl v1.2s, v1.2s, #16
-; CHECK-NEXT:    sqadd v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    sqadd v0.2s, v0.2s, v1.2s
 ; CHECK-NEXT:    ushr v0.2s, v0.2s, #16
 ; CHECK-NEXT:    mov w8, v0.s[1]
 ; CHECK-NEXT:    fmov w9, s0
diff --git a/llvm/test/CodeGen/AArch64/ssub_sat_vec.ll b/llvm/test/CodeGen/AArch64/ssub_sat_vec.ll
--- a/llvm/test/CodeGen/AArch64/ssub_sat_vec.ll
+++ b/llvm/test/CodeGen/AArch64/ssub_sat_vec.ll
@@ -146,15 +146,13 @@
 define void @v2i8(<2 x i8>* %px, <2 x i8>* %py, <2 x i8>* %pz) nounwind {
 ; CHECK-LABEL: v2i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1 { v0.b }[0], [x1]
-; CHECK-NEXT:    ld1 { v1.b }[0], [x0]
-; CHECK-NEXT:    add x8, x0, #1 // =1
-; CHECK-NEXT:    add x9, x1, #1 // =1
-; CHECK-NEXT:    ld1 { v0.b }[4], [x9]
-; CHECK-NEXT:    ld1 { v1.b }[4], [x8]
+; CHECK-NEXT:    ld1 { v0.b }[0], [x0], #1
+; CHECK-NEXT:    ld1 { v0.b }[4], [x0]
+; CHECK-NEXT:    ld1 { v1.b }[0], [x1], #1
 ; CHECK-NEXT:    shl v0.2s, v0.2s, #24
+; CHECK-NEXT:    ld1 { v1.b }[4], [x1]
 ; CHECK-NEXT:    shl v1.2s, v1.2s, #24
-; CHECK-NEXT:    sqsub v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    sqsub v0.2s, v0.2s, v1.2s
 ; CHECK-NEXT:    ushr v0.2s, v0.2s, #24
 ; CHECK-NEXT:    mov w8, v0.s[1]
 ; CHECK-NEXT:    fmov w9, s0
@@ -186,15 +184,13 @@
 define void @v2i16(<2 x i16>* %px, <2 x i16>* %py, <2 x i16>* %pz) nounwind {
 ; CHECK-LABEL: v2i16:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    ld1 { v0.h }[0], [x1]
-; CHECK-NEXT:    ld1 { v1.h }[0], [x0]
-; CHECK-NEXT:    add x8, x0, #2 // =2
-; CHECK-NEXT:    add x9, x1, #2 // =2
-; CHECK-NEXT:    ld1 { v0.h }[2], [x9]
-; CHECK-NEXT:    ld1 { v1.h }[2], [x8]
+; CHECK-NEXT:    ld1 { v0.h }[0], [x0], #2
+; CHECK-NEXT:    ld1 { v0.h }[2], [x0]
+; CHECK-NEXT:    ld1 { v1.h }[0], [x1], #2
 ; CHECK-NEXT:    shl v0.2s, v0.2s, #16
+; CHECK-NEXT:    ld1 { v1.h }[2], [x1]
 ; CHECK-NEXT:    shl v1.2s, v1.2s, #16
-; CHECK-NEXT:    sqsub v0.2s, v1.2s, v0.2s
+; CHECK-NEXT:    sqsub v0.2s, v0.2s, v1.2s
 ; CHECK-NEXT:    ushr v0.2s, v0.2s, #16
 ; CHECK-NEXT:    mov w8, v0.s[1]
 ; CHECK-NEXT:    fmov w9, s0
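
For reference, a minimal standalone sketch (plain C++, not the LLVM API) of the lane and offset arithmetic the EXTLOAD path above performs. The concrete numbers are assumptions drawn from the v2i16 tests: VT is v2i32 in the DAG, MemVT is i16 in memory, and the element at lane 1 is being inserted; the computed values line up with the "ld1 { v0.h }[2]" and "[x0], #2" checks.

    #include <cassert>
    #include <cstdio>

    int main() {
      unsigned VTElemBits = 32;  // VT = v2i32: element size in the DAG
      unsigned VTNumElems = 2;
      unsigned MemBits = 16;     // MemVT = i16: what the extending load reads
      unsigned Lane = 1;         // insert_vector_elt lane within VT

      // Mirrors "LaneScale *= VT.getScalarSizeInBits() / MemVT.getScalarSizeInBits()":
      // rescale the lane so it indexes ResultVT (v4i16) instead of VT (v2i32).
      unsigned LaneScale = VTElemBits / MemBits;                    // 2
      unsigned ResultNumElems = VTNumElems * VTElemBits / MemBits;  // v4i16 -> 4 lanes
      unsigned NewLane = Lane * LaneScale;                          // 2
      assert(NewLane < ResultNumElems);

      // Mirrors "NumBytes = MemVT.getScalarSizeInBits() / 8": the post-increment
      // written back by LD1LANEpost is the memory element size, not the DAG element size.
      unsigned NumBytes = MemBits / 8;                              // 2

      std::printf("lane %u of v2i32 -> lane %u of v4i16; post-increment #%u bytes\n",
                  Lane, NewLane, NumBytes);
      return 0;
    }

The same LaneScale accumulation handles the INSERT_SUBVECTOR case, where the factor is VT.getVectorNumElements() / SubvecVT.getVectorNumElements() instead of the bit-width ratio.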