diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -236,9 +236,6 @@
   case Instruction::Add:
     break;
   case Instruction::Shl:
-    // Only support shift by constant.
-    if (!isa<ConstantInt>(BO->getOperand(1)))
-      return false;
     break;
   case Instruction::Mul:
     break;
@@ -251,7 +248,8 @@
     Index = cast<Instruction>(BO->getOperand(0));
     OtherOp = BO->getOperand(1);
   } else if (isa<Instruction>(BO->getOperand(1)) &&
-             L->contains(cast<Instruction>(BO->getOperand(1)))) {
+             L->contains(cast<Instruction>(BO->getOperand(1))) &&
+             Instruction::isCommutative(BO->getOpcode())) {
     Index = cast<Instruction>(BO->getOperand(1));
     OtherOp = BO->getOperand(0);
   } else {
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
@@ -310,6 +310,100 @@
   ret void
 }
 
+define void @gather_unknown_pow2(ptr noalias nocapture %A, ptr noalias nocapture readonly %B, i64 %shift) {
+; CHECK-LABEL: @gather_unknown_pow2(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[STEP:%.*]] = shl i64 8, [[SHIFT:%.*]]
+; CHECK-NEXT:    [[STRIDE:%.*]] = shl i64 1, [[SHIFT]]
+; CHECK-NEXT:    [[TMP0:%.*]] = mul i64 [[STRIDE]], 4
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[ENTRY]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[B:%.*]], i64 [[VEC_IND_SCALAR]]
+; CHECK-NEXT:    [[WIDE_MASKED_GATHER:%.*]] = call <8 x i32> @llvm.riscv.masked.strided.load.v8i32.p0.i64(<8 x i32> undef, ptr [[TMP1]], i64 [[TMP0]], <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+; CHECK-NEXT:    [[I2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[I2]], align 1
+; CHECK-NEXT:    [[I4:%.*]] = add <8 x i32> [[WIDE_LOAD]], [[WIDE_MASKED_GATHER]]
+; CHECK-NEXT:    store <8 x i32> [[I4]], ptr [[I2]], align 1
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
+; CHECK-NEXT:    [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[STEP]]
+; CHECK-NEXT:    [[I6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
+; CHECK-NEXT:    br i1 [[I6]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %.splatinsert = insertelement <8 x i64> poison, i64 %shift, i64 0
+  %.splat = shufflevector <8 x i64> %.splatinsert, <8 x i64> poison, <8 x i32> zeroinitializer
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %entry
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %vec.ind = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, %entry ], [ %vec.ind.next, %vector.body ]
+  %i = shl nsw <8 x i64> %vec.ind, %.splat
+  %i1 = getelementptr inbounds i32, ptr %B, <8 x i64> %i
+  %wide.masked.gather = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+  %i2 = getelementptr inbounds i32, ptr %A, i64 %index
+  %wide.load = load <8 x i32>, ptr %i2, align 1
+  %i4 = add <8 x i32> %wide.load, %wide.masked.gather
+  store <8 x i32> %i4, ptr %i2, align 1
+  %index.next = add nuw i64 %index, 8
+  %vec.ind.next = add <8 x i64> %vec.ind, <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
+  %i6 = icmp eq i64 %index.next, 1024
+  br i1 %i6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:                                 ; preds = %vector.body
+  ret void
+}
+
+define void @negative_shl_non_commute(ptr noalias nocapture %A, ptr noalias nocapture readonly %B, i64 %shift) {
+; CHECK-LABEL: @negative_shl_non_commute(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <8 x i64> poison, i64 [[SHIFT:%.*]], i64 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <8 x i64> [[DOTSPLATINSERT]], <8 x i64> poison, <8 x i32> zeroinitializer
+; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
+; CHECK:       vector.body:
+; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, [[ENTRY]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[I:%.*]] = shl nsw <8 x i64> [[DOTSPLAT]], [[VEC_IND]]
+; CHECK-NEXT:    [[I1:%.*]] = getelementptr inbounds i32, ptr [[B:%.*]], <8 x i64> [[I]]
+; CHECK-NEXT:    [[WIDE_MASKED_GATHER:%.*]] = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> [[I1]], i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+; CHECK-NEXT:    [[I2:%.*]] = getelementptr inbounds i32, ptr [[A:%.*]], i64 [[INDEX]]
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <8 x i32>, ptr [[I2]], align 1
+; CHECK-NEXT:    [[I4:%.*]] = add <8 x i32> [[WIDE_LOAD]], [[WIDE_MASKED_GATHER]]
+; CHECK-NEXT:    store <8 x i32> [[I4]], ptr [[I2]], align 1
+; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 8
+; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <8 x i64> [[VEC_IND]], <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
+; CHECK-NEXT:    [[I6:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
+; CHECK-NEXT:    br i1 [[I6]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
+; CHECK:       for.cond.cleanup:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %.splatinsert = insertelement <8 x i64> poison, i64 %shift, i64 0
+  %.splat = shufflevector <8 x i64> %.splatinsert, <8 x i64> poison, <8 x i32> zeroinitializer
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %entry
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %vec.ind = phi <8 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7>, %entry ], [ %vec.ind.next, %vector.body ]
+  %i = shl nsw <8 x i64> %.splat, %vec.ind
+  %i1 = getelementptr inbounds i32, ptr %B, <8 x i64> %i
+  %wide.masked.gather = call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x i32> undef)
+  %i2 = getelementptr inbounds i32, ptr %A, i64 %index
+  %wide.load = load <8 x i32>, ptr %i2, align 1
+  %i4 = add <8 x i32> %wide.load, %wide.masked.gather
+  store <8 x i32> %i4, ptr %i2, align 1
+  %index.next = add nuw i64 %index, 8
+  %vec.ind.next = add <8 x i64> %vec.ind, <i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8, i64 8>
+  %i6 = icmp eq i64 %index.next, 1024
+  br i1 %i6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:                                 ; preds = %vector.body
+  ret void
+}
+
 ;void scatter_pow2(signed char * __restrict A, signed char * __restrict B) {
 ;  for (int i = 0; i < 1024; ++i)
 ;    A[i * 4] += B[i];