diff --git a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp
@@ -21,9 +21,11 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/IntrinsicsRISCV.h"
+#include "llvm/IR/PatternMatch.h"
 #include "llvm/Transforms/Utils/Local.h"
 
 using namespace llvm;
+using namespace PatternMatch;
 
 #define DEBUG_TYPE "riscv-gather-scatter-lowering"
 
@@ -139,6 +141,12 @@
   if (StartC)
     return matchStridedConstant(StartC);
 
+  // Base case, start is a stepvector
+  if (match(Start, m_Intrinsic<Intrinsic::experimental_stepvector>())) {
+    auto *Ty = Start->getType()->getScalarType();
+    return std::make_pair(ConstantInt::get(Ty, 0), ConstantInt::get(Ty, 1));
+  }
+
   // Not a constant, maybe it's a strided constant with a splat added to it.
   auto *BO = dyn_cast<BinaryOperator>(Start);
   if (!BO || BO->getOpcode() != Instruction::Add)
@@ -482,11 +490,9 @@
   for (BasicBlock &BB : F) {
     for (Instruction &I : BB) {
       IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
-      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
-          isa<FixedVectorType>(II->getType())) {
+      if (II && II->getIntrinsicID() == Intrinsic::masked_gather) {
         Gathers.push_back(II);
-      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
-                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
+      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter) {
         Scatters.push_back(II);
       }
     }
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
--- a/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll
@@ -10,21 +10,18 @@
 ; CHECK-NEXT:  vector.ph:
 ; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[LEN:%.*]] to i64
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[TMP0]], i64 0
-; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[DOTSPLATINSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <vscale x 1 x i64> [ [[TMP1]], [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[ACCUM:%.*]] = phi <vscale x 1 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[ACCUM_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds [[STRUCT_FOO:%.*]], ptr [[A:%.*]], <vscale x 1 x i64> [[VEC_IND]], i32 3
-; CHECK-NEXT:    [[GATHER:%.*]] = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> [[TMP2]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i32 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer), <vscale x 1 x i64> undef)
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [[STRUCT_FOO:%.*]], ptr [[A:%.*]], i64 [[VEC_IND_SCALAR]], i32 3
+; CHECK-NEXT:    [[GATHER:%.*]] = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> undef, ptr [[TMP1]], i64 16, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i32 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    [[ACCUM_NEXT]] = add <vscale x 1 x i64> [[ACCUM]], [[GATHER]]
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP0]]
-; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 1 x i64> [[VEC_IND]], [[DOTSPLAT]]
-; CHECK-NEXT:    [[TMP3:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]]
-; CHECK-NEXT:    br i1 [[TMP3]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
+; CHECK-NEXT:    [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]]
+; CHECK-NEXT:    br i1 [[TMP2]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
 ; CHECK:       for.cond.cleanup:
 ; CHECK-NEXT:    ret <vscale x 1 x i64> [[ACCUM_NEXT]]
 ;
@@ -57,19 +54,16 @@
 ; CHECK-NEXT:  vector.ph:
 ; CHECK-NEXT:    [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[LEN:%.*]] to i64
 ; CHECK-NEXT:    [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
-; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
-; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 1 x i64> poison, i64 [[TMP0]], i64 0
-; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 1 x i64> [[DOTSPLATINSERT]], <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <vscale x 1 x i64> [ [[TMP1]], [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NEXT:    [[TMP2:%.*]] = getelementptr inbounds [[STRUCT_FOO:%.*]], ptr [[A:%.*]], <vscale x 1 x i64> [[VEC_IND]], i32 3
-; CHECK-NEXT:    tail call void @llvm.masked.scatter.nxv1i64.nxv1p0(<vscale x 1 x i64> zeroinitializer, <vscale x 1 x ptr> [[TMP2]], i32 8, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i32 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
+; CHECK-NEXT:    [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr [[STRUCT_FOO:%.*]], ptr [[A:%.*]], i64 [[VEC_IND_SCALAR]], i32 3
+; CHECK-NEXT:    call void @llvm.riscv.masked.strided.store.nxv1i64.p0.i64(<vscale x 1 x i64> zeroinitializer, ptr [[TMP1]], i64 16, <vscale x 1 x i1> shufflevector (<vscale x 1 x i1> insertelement (<vscale x 1 x i1> poison, i1 true, i32 0), <vscale x 1 x i1> poison, <vscale x 1 x i32> zeroinitializer))
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP0]]
-; CHECK-NEXT:    [[VEC_IND_NEXT]] = add <vscale x 1 x i64> [[VEC_IND]], [[DOTSPLAT]]
-; CHECK-NEXT:    [[TMP3:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]]
-; CHECK-NEXT:    br i1 [[TMP3]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
+; CHECK-NEXT:    [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[TMP0]]
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]]
+; CHECK-NEXT:    br i1 [[TMP2]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
 ; CHECK:       for.cond.cleanup:
 ; CHECK-NEXT:    ret void
 ;
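
For reference, a minimal LLVM IR sketch of the rewrite the new stepvector base case enables, distilled from the test update above. The names %a, %mask, %struct.foo and the 16-byte stride are illustrative assumptions (the stride mirrors the i64 16 in the test), not part of the patch:

; Before: the vectorizer emits a gather whose index vector starts at stepvector,
; i.e. element addresses that advance by a constant stride.
%step = call <vscale x 1 x i64> @llvm.experimental.stepvector.nxv1i64()
%ptrs = getelementptr %struct.foo, ptr %a, <vscale x 1 x i64> %step, i32 3
%v = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> %ptrs, i32 8, <vscale x 1 x i1> %mask, <vscale x 1 x i64> undef)

; After: matchStridedStart now recognises stepvector as (start 0, stride 1), so
; the pass can rewrite the access as a strided load from a scalar base pointer.
%base = getelementptr %struct.foo, ptr %a, i64 0, i32 3
%v2 = call <vscale x 1 x i64> @llvm.riscv.masked.strided.load.nxv1i64.p0.i64(<vscale x 1 x i64> undef, ptr %base, i64 16, <vscale x 1 x i1> %mask)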