diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -127,6 +127,41 @@
   return std::make_pair(StartVal, Stride);
 }
 
+static std::pair<Value *, Value *> matchStridedStart(Value *Start,
+                                                     IRBuilder<> &Builder) {
+  // Base case, start is a strided constant.
+  auto *StartC = dyn_cast<Constant>(Start);
+  if (StartC)
+    return matchStridedConstant(StartC);
+
+  // Not a constant, maybe it's a strided constant with a splat added to it.
+  auto *BO = dyn_cast<BinaryOperator>(Start);
+  if (!BO || BO->getOpcode() != Instruction::Add)
+    return std::make_pair(nullptr, nullptr);
+
+  // Look for an operand that is splatted.
+  unsigned OtherIndex = 1;
+  Value *Splat = getSplatValue(BO->getOperand(0));
+  if (!Splat) {
+    Splat = getSplatValue(BO->getOperand(1));
+    OtherIndex = 0;
+  }
+  if (!Splat)
+    return std::make_pair(nullptr, nullptr);
+
+  Value *Stride;
+  std::tie(Start, Stride) = matchStridedStart(BO->getOperand(OtherIndex),
+                                              Builder);
+  if (!Start)
+    return std::make_pair(nullptr, nullptr);
+
+  // Add the splat value to the start.
+  Builder.SetInsertPoint(BO);
+  Builder.SetCurrentDebugLocation(DebugLoc());
+  Start = Builder.CreateAdd(Start, Splat);
+  return std::make_pair(Start, Stride);
+}
+
 // Recursively, walk about the use-def chain until we find a Phi with a strided
 // start value. Build and update a scalar recurrence as we unwind the recursion.
 // We also update the Stride as we unwind. Our goal is to move all of the
@@ -161,12 +196,7 @@
   if (!Step)
     return false;
 
-  // Start should be a strided constant.
-  auto *StartC = dyn_cast<Constant>(Start);
-  if (!StartC)
-    return false;
-
-  std::tie(Start, Stride) = matchStridedConstant(StartC);
+  std::tie(Start, Stride) = matchStridedStart(Start, Builder);
   if (!Start)
     return false;
   assert(Stride != nullptr);
diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
--- a/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vector-strided-load-store.ll
@@ -978,3 +978,173 @@
 }
 
 declare void @llvm.masked.scatter.v2p0i32.v2p0p0i32(<2 x i32*>, <2 x i32**>, i32 immarg, <2 x i1>)
+
+define void @strided_load_startval_add_with_splat(i8* noalias nocapture %0, i8* noalias nocapture readonly %1, i32 signext %2) {
+; CHECK-LABEL: @strided_load_startval_add_with_splat(
+; CHECK-NEXT:    [[TMP4:%.*]] = icmp eq i32 [[TMP2:%.*]], 1024
+; CHECK-NEXT:    br i1 [[TMP4]], label [[TMP31:%.*]], label [[TMP5:%.*]]
+; CHECK:       5:
+; CHECK-NEXT:    [[TMP6:%.*]] = sext i32 [[TMP2]] to i64
+; CHECK-NEXT:    [[TMP7:%.*]] = sub i32 1023, [[TMP2]]
+; CHECK-NEXT:    [[TMP8:%.*]] = zext i32 [[TMP7]] to i64
+; CHECK-NEXT:    [[TMP9:%.*]] = add nuw nsw i64 [[TMP8]], 1
+; CHECK-NEXT:    [[TMP10:%.*]] = icmp ult i32 [[TMP7]], 31
+; CHECK-NEXT:    br i1 [[TMP10]], label [[TMP29:%.*]], label [[TMP11:%.*]]
+; CHECK:       11:
+; CHECK-NEXT:    [[TMP12:%.*]] = and i64 [[TMP9]], 8589934560
+; CHECK-NEXT:    [[TMP13:%.*]] = add nsw i64 [[TMP12]], [[TMP6]]
+; CHECK-NEXT:    [[TMP14:%.*]] = add i64 0, [[TMP6]]
+; CHECK-NEXT:    [[START:%.*]] = mul i64 [[TMP14]], 5
+; CHECK-NEXT:    br label [[TMP15:%.*]]
+; CHECK:       15:
+; CHECK-NEXT:    [[TMP16:%.*]] = phi i64 [ 0, [[TMP11]] ], [ [[TMP25:%.*]], [[TMP15]] ]
+; CHECK-NEXT:    [[DOTSCALAR:%.*]] = phi i64 [ [[START]], [[TMP11]] ], [ [[DOTSCALAR1:%.*]], [[TMP15]] ]
+; CHECK-NEXT:    [[TMP17:%.*]] = add i64 [[TMP16]], [[TMP6]]
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr i8, i8* [[TMP1:%.*]], i64 [[DOTSCALAR]]
+; CHECK-NEXT:    [[TMP19:%.*]] = call <32 x i8> @llvm.riscv.masked.strided.load.v32i8.p0i8.i64(<32 x i8> undef, i8* [[TMP18]], i64 5, <32 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>)
+; CHECK-NEXT:    [[TMP20:%.*]] = getelementptr inbounds i8, i8* [[TMP0:%.*]], i64 [[TMP17]]
+; CHECK-NEXT:    [[TMP21:%.*]] = bitcast i8* [[TMP20]] to <32 x i8>*
+; CHECK-NEXT:    [[TMP22:%.*]] = load <32 x i8>, <32 x i8>* [[TMP21]], align 1
+; CHECK-NEXT:    [[TMP23:%.*]] = add <32 x i8> [[TMP22]], [[TMP19]]
+; CHECK-NEXT:    [[TMP24:%.*]] = bitcast i8* [[TMP20]] to <32 x i8>*
+; CHECK-NEXT:    store <32 x i8> [[TMP23]], <32 x i8>* [[TMP24]], align 1
+; CHECK-NEXT:    [[TMP25]] = add nuw i64 [[TMP16]], 32
+; CHECK-NEXT:    [[DOTSCALAR1]] = add i64 [[DOTSCALAR]], 160
+; CHECK-NEXT:    [[TMP26:%.*]] = icmp eq i64 [[TMP25]], [[TMP12]]
+; CHECK-NEXT:    br i1 [[TMP26]], label [[TMP27:%.*]], label [[TMP15]]
+; CHECK:       27:
+; CHECK-NEXT:    [[TMP28:%.*]] = icmp eq i64 [[TMP9]], [[TMP12]]
+; CHECK-NEXT:    br i1 [[TMP28]], label [[TMP31]], label [[TMP29]]
+; CHECK:       29:
+; CHECK-NEXT:    [[TMP30:%.*]] = phi i64 [ [[TMP6]], [[TMP5]] ], [ [[TMP13]], [[TMP27]] ]
+; CHECK-NEXT:    br label [[TMP32:%.*]]
+; CHECK:       31:
+; CHECK-NEXT:    ret void
+; CHECK:       32:
+; CHECK-NEXT:    [[TMP33:%.*]] = phi i64 [ [[TMP40:%.*]], [[TMP32]] ], [ [[TMP30]], [[TMP29]] ]
+; CHECK-NEXT:    [[TMP34:%.*]] = mul nsw i64 [[TMP33]], 5
+; CHECK-NEXT:    [[TMP35:%.*]] = getelementptr inbounds i8, i8* [[TMP1]], i64 [[TMP34]]
+; CHECK-NEXT:    [[TMP36:%.*]] = load i8, i8* [[TMP35]], align 1
+; CHECK-NEXT:    [[TMP37:%.*]] = getelementptr inbounds i8, i8* [[TMP0]], i64 [[TMP33]]
+; CHECK-NEXT:    [[TMP38:%.*]] = load i8, i8* [[TMP37]], align 1
+; CHECK-NEXT:    [[TMP39:%.*]] = add i8 [[TMP38]], [[TMP36]]
+; CHECK-NEXT:    store i8 [[TMP39]], i8* [[TMP37]], align 1
+; CHECK-NEXT:    [[TMP40]] = add nsw i64 [[TMP33]], 1
+; CHECK-NEXT:    [[TMP41:%.*]] = trunc i64 [[TMP40]] to i32
+; CHECK-NEXT:    [[TMP42:%.*]] = icmp eq i32 [[TMP41]], 1024
+; CHECK-NEXT:    br i1 [[TMP42]], label [[TMP31]], label [[TMP32]]
+;
+; CHECK-ASM-LABEL: strided_load_startval_add_with_splat:
+; CHECK-ASM:       # %bb.0:
+; CHECK-ASM-NEXT:    li a3, 1024
+; CHECK-ASM-NEXT:    beq a2, a3, .LBB12_7
+; CHECK-ASM-NEXT:  # %bb.1:
+; CHECK-ASM-NEXT:    li a3, 1023
+; CHECK-ASM-NEXT:    subw a4, a3, a2
+; CHECK-ASM-NEXT:    li a5, 31
+; CHECK-ASM-NEXT:    mv a3, a2
+; CHECK-ASM-NEXT:    bltu a4, a5, .LBB12_5
+; CHECK-ASM-NEXT:  # %bb.2:
+; CHECK-ASM-NEXT:    slli a3, a4, 32
+; CHECK-ASM-NEXT:    srli a3, a3, 32
+; CHECK-ASM-NEXT:    addi a6, a3, 1
+; CHECK-ASM-NEXT:    andi a7, a6, -32
+; CHECK-ASM-NEXT:    add a3, a7, a2
+; CHECK-ASM-NEXT:    slli a4, a2, 2
+; CHECK-ASM-NEXT:    add a4, a4, a2
+; CHECK-ASM-NEXT:    add a2, a0, a2
+; CHECK-ASM-NEXT:    add a4, a1, a4
+; CHECK-ASM-NEXT:    li t0, 32
+; CHECK-ASM-NEXT:    li t1, 5
+; CHECK-ASM-NEXT:    mv a5, a7
+; CHECK-ASM-NEXT:  .LBB12_3: # =>This Inner Loop Header: Depth=1
+; CHECK-ASM-NEXT:    vsetvli zero, t0, e8, m1, ta, mu
+; CHECK-ASM-NEXT:    vlse8.v v8, (a4), t1
+; CHECK-ASM-NEXT:    vle8.v v9, (a2)
+; CHECK-ASM-NEXT:    vadd.vv v8, v9, v8
+; CHECK-ASM-NEXT:    vse8.v v8, (a2)
+; CHECK-ASM-NEXT:    addi a5, a5, -32
+; CHECK-ASM-NEXT:    addi a2, a2, 32
+; CHECK-ASM-NEXT:    addi a4, a4, 160
+; CHECK-ASM-NEXT:    bnez a5, .LBB12_3
+; CHECK-ASM-NEXT:  # %bb.4:
+; CHECK-ASM-NEXT:    beq a6, a7, .LBB12_7
+; CHECK-ASM-NEXT:  .LBB12_5:
+; CHECK-ASM-NEXT:    slli a2, a3, 2
+; CHECK-ASM-NEXT:    add a2, a2, a3
+; CHECK-ASM-NEXT:    add a1, a1, a2
+; CHECK-ASM-NEXT:    li a6, 1024
+; CHECK-ASM-NEXT:  .LBB12_6: # =>This Inner Loop Header: Depth=1
+; CHECK-ASM-NEXT:    lb a4, 0(a1)
+; CHECK-ASM-NEXT:    add a5, a0, a3
+; CHECK-ASM-NEXT:    lb a2, 0(a5)
+; CHECK-ASM-NEXT:    addw a2, a2, a4
+; CHECK-ASM-NEXT:    sb a2, 0(a5)
+; CHECK-ASM-NEXT:    addiw a2, a3, 1
+; CHECK-ASM-NEXT:    addi a3, a3, 1
+; CHECK-ASM-NEXT:    addi a1, a1, 5
+; CHECK-ASM-NEXT:    bne a2, a6, .LBB12_6
+; CHECK-ASM-NEXT:  .LBB12_7:
+; CHECK-ASM-NEXT:    ret
+  %4 = icmp eq i32 %2, 1024
+  br i1 %4, label %36, label %5
+
+5:                                                ; preds = %3
+  %6 = sext i32 %2 to i64
+  %7 = sub i32 1023, %2
+  %8 = zext i32 %7 to i64
+  %9 = add nuw nsw i64 %8, 1
+  %10 = icmp ult i32 %7, 31
+  br i1 %10, label %34, label %11
+
+11:                                               ; preds = %5
+  %12 = and i64 %9, 8589934560
+  %13 = add nsw i64 %12, %6
+  %14 = insertelement <32 x i64> poison, i64 %6, i64 0
+  %15 = shufflevector <32 x i64> %14, <32 x i64> poison, <32 x i32> zeroinitializer
+  %16 = add <32 x i64> %15, <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15, i64 16, i64 17, i64 18, i64 19, i64 20, i64 21, i64 22, i64 23, i64 24, i64 25, i64 26, i64 27, i64 28, i64 29, i64 30, i64 31>
+  br label %17
+
+17:                                               ; preds = %17, %11
+  %18 = phi i64 [ 0, %11 ], [ %29, %17 ]
+  %19 = phi <32 x i64> [ %16, %11 ], [ %30, %17 ]
+  %20 = add i64 %18, %6
+  %21 = mul nsw <32 x i64> %19, <i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5>
+  %22 = getelementptr inbounds i8, i8* %1, <32 x i64> %21
+  %23 = call <32 x i8> @llvm.masked.gather.v32i8.v32p0i8(<32 x i8*> %22, i32 1, <32 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <32 x i8> undef)
+  %24 = getelementptr inbounds i8, i8* %0, i64 %20
+  %25 = bitcast i8* %24 to <32 x i8>*
+  %26 = load <32 x i8>, <32 x i8>* %25, align 1
+  %27 = add <32 x i8> %26, %23
+  %28 = bitcast i8* %24 to <32 x i8>*
+  store <32 x i8> %27, <32 x i8>* %28, align 1
+  %29 = add nuw i64 %18, 32
+  %30 = add <32 x i64> %19, <i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32>
+  %31 = icmp eq i64 %29, %12
+  br i1 %31, label %32, label %17
+
+32:                                               ; preds = %17
+  %33 = icmp eq i64 %9, %12
+  br i1 %33, label %36, label %34
+
+34:                                               ; preds = %5, %32
+  %35 = phi i64 [ %6, %5 ], [ %13, %32 ]
+  br label %37
+
+36:                                               ; preds = %37, %32, %3
+  ret void
+
+37:                                               ; preds = %34, %37
+  %38 = phi i64 [ %45, %37 ], [ %35, %34 ]
+  %39 = mul nsw i64 %38, 5
+  %40 = getelementptr inbounds i8, i8* %1, i64 %39
+  %41 = load i8, i8* %40, align 1
+  %42 = getelementptr inbounds i8, i8* %0, i64 %38
+  %43 = load i8, i8* %42, align 1
+  %44 = add i8 %43, %41
+  store i8 %44, i8* %42, align 1
+  %45 = add nsw i64 %38, 1
+  %46 = trunc i64 %45 to i32
+  %47 = icmp eq i32 %46, 1024
+  br i1 %47, label %36, label %37
+}
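For context, a minimal illustrative sketch (not part of the patch; the function @strided_with_splat_example, its parameters, and the trip count are made up) of the index-recurrence shape the new matchStridedStart is intended to recognize: the vector induction variable's start value is a splatted scalar added to a strided constant rather than a plain strided constant, so the recursion peels the splat off, matches the underlying constant, and re-adds the splat to the scalar start value.

; Illustration only: indices start at splat(%offset) + <0, 2, 4, 6> (byte stride 2)
; and advance by 8 per iteration.
define void @strided_with_splat_example(i8* noalias %dst, i8* noalias readonly %src, i64 %offset) {
entry:
  %head = insertelement <4 x i64> poison, i64 %offset, i64 0
  %splat = shufflevector <4 x i64> %head, <4 x i64> poison, <4 x i32> zeroinitializer
  ; Start value is "splat + strided constant"; previously only a plain Constant start was matched.
  %start = add <4 x i64> %splat, <i64 0, i64 2, i64 4, i64 6>
  br label %loop

loop:                                             ; preds = %loop, %entry
  %i = phi i64 [ 0, %entry ], [ %i.next, %loop ]
  %vec.ind = phi <4 x i64> [ %start, %entry ], [ %vec.ind.next, %loop ]
  %ptrs = getelementptr inbounds i8, i8* %src, <4 x i64> %vec.ind
  ; With this change the splat %offset is peeled off and re-added to the scalar
  ; start, so this gather should still be convertible to a strided load.
  %gather = call <4 x i8> @llvm.masked.gather.v4i8.v4p0i8(<4 x i8*> %ptrs, i32 1, <4 x i1> <i1 true, i1 true, i1 true, i1 true>, <4 x i8> undef)
  %out = getelementptr inbounds i8, i8* %dst, i64 %i
  %out.cast = bitcast i8* %out to <4 x i8>*
  store <4 x i8> %gather, <4 x i8>* %out.cast, align 1
  %i.next = add nuw i64 %i, 4
  %vec.ind.next = add <4 x i64> %vec.ind, <i64 8, i64 8, i64 8, i64 8>
  %done = icmp eq i64 %i.next, 1024
  br i1 %done, label %exit, label %loop

exit:                                             ; preds = %loop
  ret void
}

declare <4 x i8> @llvm.masked.gather.v4i8.v4p0i8(<4 x i8*>, i32 immarg, <4 x i1>, <4 x i8>)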