Index: llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -10668,23 +10668,33 @@
 }
 
 bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
-                       SelectionDAG &DAG) {
-  if (!isNullConstant(BasePtr) || Index.getOpcode() != ISD::ADD)
+                       SelectionDAG &DAG, const SDLoc &DL) {
+  if (Index.getOpcode() != ISD::ADD)
     return false;
 
   // Only perform the transformation when existing operands can be reused.
   if (IndexIsScaled)
     return false;
 
+  if (!isNullConstant(BasePtr) && !Index.hasOneUse())
+    return false;
+
+  EVT VT = BasePtr.getValueType();
   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(0));
-      SplatVal && SplatVal.getValueType() == BasePtr.getValueType()) {
-    BasePtr = SplatVal;
+      SplatVal && SplatVal.getValueType() == VT) {
+    if (isNullConstant(BasePtr))
+      BasePtr = SplatVal;
+    else
+      BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
     Index = Index.getOperand(1);
     return true;
   }
 
   if (SDValue SplatVal = DAG.getSplatValue(Index.getOperand(1));
-      SplatVal && SplatVal.getValueType() == BasePtr.getValueType()) {
-    BasePtr = SplatVal;
+      SplatVal && SplatVal.getValueType() == VT) {
+    if (isNullConstant(BasePtr))
+      BasePtr = SplatVal;
+    else
+      BasePtr = DAG.getNode(ISD::ADD, DL, VT, BasePtr, SplatVal);
     Index = Index.getOperand(0);
     return true;
   }
@@ -10769,7 +10779,7 @@
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return Chain;
 
-  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
+  if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG, DL)) {
     SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
                                 DL, Ops, MSC->getMemOperand(), IndexType,
@@ -10893,7 +10903,7 @@
   if (ISD::isConstantSplatVectorAllZeros(Mask.getNode()))
     return CombineTo(N, PassThru, MGT->getChain());
 
-  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
+  if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG, DL)) {
     SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
     return DAG.getMaskedGather(
         DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
Index: llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
===================================================================
--- llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
+++ llvm/test/CodeGen/AArch64/sve-gather-scatter-addr-opts.ll
@@ -105,18 +105,16 @@
 ; CHECK-NEXT:    rdvl x8, #1
 ; CHECK-NEXT:    mov w9, #67108864
 ; CHECK-NEXT:    lsr x8, x8, #4
-; CHECK-NEXT:    add x10, x0, x1
+; CHECK-NEXT:    add x11, x0, x1
+; CHECK-NEXT:    mov w10, #33554432
 ; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    uunpklo z3.d, z0.s
-; CHECK-NEXT:    mul x8, x8, x9
-; CHECK-NEXT:    mov w9, #33554432
+; CHECK-NEXT:    madd x8, x8, x9, x11
+; CHECK-NEXT:    uunpklo z2.d, z0.s
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    uunpkhi z0.d, z0.s
-; CHECK-NEXT:    index z1.d, #0, x9
-; CHECK-NEXT:    mov z2.d, x8
-; CHECK-NEXT:    st1b { z3.d }, p1, [x10, z1.d]
-; CHECK-NEXT:    add z2.d, z1.d, z2.d
-; CHECK-NEXT:    st1b { z0.d }, p0, [x10, z2.d]
+; CHECK-NEXT:    index z1.d, #0, x10
+; CHECK-NEXT:    st1b { z2.d }, p1, [x11, z1.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x8, z1.d]
 ; CHECK-NEXT:    ret
   %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
   %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
@@ -138,18 +136,16 @@
 ; CHECK-NEXT:    mov x9, #-2
 ; CHECK-NEXT:    lsr x8, x8, #4
 ; CHECK-NEXT:    movk x9, #64511, lsl #16
-; CHECK-NEXT:    add x10, x0, x1
+; CHECK-NEXT:    add x11, x0, x1
+; CHECK-NEXT:    mov x10, #-33554433
+; CHECK-NEXT:    madd x8, x8, x9, x11
 ; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    mul x8, x8, x9
-; CHECK-NEXT:    mov x9, #-33554433
-; CHECK-NEXT:    uunpklo z3.d, z0.s
+; CHECK-NEXT:    uunpklo z2.d, z0.s
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
+; CHECK-NEXT:    index z1.d, #0, x10
 ; CHECK-NEXT:    uunpkhi z0.d, z0.s
-; CHECK-NEXT:    index z1.d, #0, x9
-; CHECK-NEXT:    mov z2.d, x8
-; CHECK-NEXT:    st1b { z3.d }, p1, [x10, z1.d]
-; CHECK-NEXT:    add z2.d, z1.d, z2.d
-; CHECK-NEXT:    st1b { z0.d }, p0, [x10, z2.d]
+; CHECK-NEXT:    st1b { z2.d }, p1, [x11, z1.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x8, z1.d]
 ; CHECK-NEXT:    ret
   %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
   %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
@@ -170,18 +166,16 @@
 ; CHECK-NEXT:    rdvl x8, #1
 ; CHECK-NEXT:    mov x9, #-9223372036854775808
 ; CHECK-NEXT:    lsr x8, x8, #4
-; CHECK-NEXT:    add x10, x0, x1
+; CHECK-NEXT:    add x11, x0, x1
+; CHECK-NEXT:    mov x10, #4611686018427387904
 ; CHECK-NEXT:    punpklo p1.h, p0.b
-; CHECK-NEXT:    uunpklo z3.d, z0.s
-; CHECK-NEXT:    mul x8, x8, x9
-; CHECK-NEXT:    mov x9, #4611686018427387904
+; CHECK-NEXT:    madd x8, x8, x9, x11
+; CHECK-NEXT:    uunpklo z2.d, z0.s
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    uunpkhi z0.d, z0.s
-; CHECK-NEXT:    index z1.d, #0, x9
-; CHECK-NEXT:    mov z2.d, x8
-; CHECK-NEXT:    st1b { z3.d }, p1, [x10, z1.d]
-; CHECK-NEXT:    add z2.d, z1.d, z2.d
-; CHECK-NEXT:    st1b { z0.d }, p0, [x10, z2.d]
+; CHECK-NEXT:    index z1.d, #0, x10
+; CHECK-NEXT:    st1b { z2.d }, p1, [x11, z1.d]
+; CHECK-NEXT:    st1b { z0.d }, p0, [x8, z1.d]
 ; CHECK-NEXT:    ret
   %t0 = insertelement <vscale x 4 x i64> undef, i64 %offset, i32 0
   %t1 = shufflevector <vscale x 4 x i64> %t0, <vscale x 4 x i64> undef, <vscale x 4 x i32> zeroinitializer
Index: llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
===================================================================
--- llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
+++ llvm/test/CodeGen/RISCV/rvv/mscatter-combine.ll
@@ -10,25 +10,23 @@
 ; RV32-LABEL: complex_gep:
 ; RV32:       # %bb.0:
 ; RV32-NEXT:    vsetvli a1, zero, e32, m1, ta, mu
-; RV32-NEXT:    vmv.v.x v10, a0
-; RV32-NEXT:    vnsrl.wi v11, v8, 0
-; RV32-NEXT:    li a0, 48
-; RV32-NEXT:    vmadd.vx v11, a0, v10
-; RV32-NEXT:    vmv.v.i v8, 0
-; RV32-NEXT:    li a0, 28
-; RV32-NEXT:    vsoxei32.v v8, (a0), v11, v0.t
+; RV32-NEXT:    vnsrl.wi v10, v8, 0
+; RV32-NEXT:    li a1, 48
+; RV32-NEXT:    vmul.vx v8, v10, a1
+; RV32-NEXT:    addi a0, a0, 28
+; RV32-NEXT:    vmv.v.i v9, 0
+; RV32-NEXT:    vsoxei32.v v9, (a0), v8, v0.t
 ; RV32-NEXT:    ret
 ;
 ; RV64-LABEL: complex_gep:
 ; RV64:       # %bb.0:
-; RV64-NEXT:    vsetvli a1, zero, e64, m2, ta, mu
-; RV64-NEXT:    vmv.v.x v10, a0
-; RV64-NEXT:    li a0, 56
-; RV64-NEXT:    vmacc.vx v10, a0, v8
+; RV64-NEXT:    li a1, 56
+; RV64-NEXT:    vsetvli a2, zero, e64, m2, ta, mu
+; RV64-NEXT:    vmul.vx v8, v8, a1
+; RV64-NEXT:    addi a0, a0, 32
 ; RV64-NEXT:    vsetvli zero, zero, e32, m1, ta, mu
-; RV64-NEXT:    vmv.v.i v8, 0
-; RV64-NEXT:    li a0, 32
-; RV64-NEXT:    vsoxei64.v v8, (a0), v10, v0.t
+; RV64-NEXT:    vmv.v.i v10, 0
+; RV64-NEXT:    vsoxei64.v v10, (a0), v8, v0.t
 ; RV64-NEXT:    ret
   %gep = getelementptr inbounds %struct, ptr %p, <vscale x 2 x i64> %vec.ind, i32 5
   call void @llvm.masked.scatter.nxv2i32.nxv2p0(<vscale x 2 x i32> zeroinitializer, <vscale x 2 x ptr> %gep, i32 8, <vscale x 2 x i1> %m)
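
Summary of the change: refineUniformBase previously bailed out unless the scatter/gather base pointer was the null constant. It now also accepts a non-null base, provided the vector index has a single use, by folding the splatted scalar offset into the base with a scalar ADD (BasePtr + SplatVal) and keeping the remaining add operand as the index. The one-use restriction follows the existing comment in the function: the transformation should only reuse operands that become dead, not create extra work. The test updates show the payoff: on AArch64/SVE the per-vector splat-and-add of the offset (mov z2.d, x8 / add z2.d, z1.d, z2.d) disappears in favour of a scalar madd feeding the scatter base, and both unpacked halves share one index vector; on RISC-V the constant struct-field offset folds into the scalar base (addi a0, a0, 28), so vmadd.vx/vmacc.vx degrade to a plain vmul.vx.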
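For reference, a minimal IR sketch of the shape this change unlocks. This is an illustration only, not one of the patch's tests; the function and value names are invented:

; Sketch: vector index = splat(%off) + per-lane %step, over a non-null %base.
; With this patch, refineUniformBase can rewrite the scatter to use
; (%base + %off) as the uniform scalar base and %step alone as the index,
; because the add below has a single use and dies once its splat operand
; is absorbed into the base.
define void @scatter_uniform_offset(ptr %base, i64 %off, <vscale x 2 x i64> %step,
                                    <vscale x 2 x i64> %val, <vscale x 2 x i1> %mask) {
  %t0 = insertelement <vscale x 2 x i64> undef, i64 %off, i32 0
  %splat = shufflevector <vscale x 2 x i64> %t0, <vscale x 2 x i64> undef, <vscale x 2 x i32> zeroinitializer
  %idx = add <vscale x 2 x i64> %splat, %step
  %ptrs = getelementptr i64, ptr %base, <vscale x 2 x i64> %idx
  call void @llvm.masked.scatter.nxv2i64.nxv2p0(<vscale x 2 x i64> %val, <vscale x 2 x ptr> %ptrs, i32 8, <vscale x 2 x i1> %mask)
  ret void
}

declare void @llvm.masked.scatter.nxv2i64.nxv2p0(<vscale x 2 x i64>, <vscale x 2 x ptr>, i32, <vscale x 2 x i1>)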