diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -352,9 +352,27 @@
 // * Dereferenceable address & few lanes -> scalarize speculative load/selects
 // * Adjacent vector addresses -> masked.load
 // * Narrow width by halfs excluding zero/undef lanes
-// * Vector splat address w/known mask -> scalar load
 // * Vector incrementing address -> vector masked load
 Instruction *InstCombinerImpl::simplifyMaskedGather(IntrinsicInst &II) {
+  auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(2));
+  if (!ConstMask)
+    return nullptr;
+
+  // Vector splat address w/known mask -> scalar load
+  // Fold the gather to load the source vector first lane
+  // because it is reloading the same value each time
+  if (ConstMask->isAllOnesValue())
+    if (auto *SplatPtr = getSplatValue(II.getArgOperand(0))) {
+      auto *VecTy = cast<VectorType>(II.getType());
+      const Align Alignment =
+          cast<ConstantInt>(II.getArgOperand(1))->getAlignValue();
+      LoadInst *L = Builder.CreateAlignedLoad(VecTy->getElementType(), SplatPtr,
+                                              Alignment, "load.scalar");
+      Value *Shuf =
+          Builder.CreateVectorSplat(VecTy->getElementCount(), L, "broadcast");
+      return replaceInstUsesWith(II, cast<Instruction>(Shuf));
+    }
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
--- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
@@ -376,3 +376,65 @@
 ; Function Attrs:
 declare void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16>, <4 x i16*>, i32 immarg, <4 x i1>)
 declare void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32 immarg, <vscale x 4 x i1>)
+
+; Test gathers that can be simplified to scalar load + splat
+
+;; Splat address and all active mask
+define <vscale x 2 x i64> @gather_nxv2i64_uniform_ptrs_all_active_mask(i64* %src) {
+; CHECK-LABEL: @gather_nxv2i64_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:    [[LOAD_SCALAR:%.*]] = load i64, i64* [[SRC:%.*]], align 8
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[LOAD_SCALAR]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <vscale x 2 x i64> [[BROADCAST_SPLATINSERT1]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    ret <vscale x 2 x i64> [[BROADCAST_SPLAT2]]
+;
+  %broadcast.splatinsert = insertelement <vscale x 2 x i64*> poison, i64 *%src, i32 0
+  %broadcast.splat = shufflevector <vscale x 2 x i64*> %broadcast.splatinsert, <vscale x 2 x i64*> poison, <vscale x 2 x i32> zeroinitializer
+  %res = call <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x i64*> %broadcast.splat, i32 8, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i32 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer), <vscale x 2 x i64> undef)
+  ret <vscale x 2 x i64> %res
+}
+
+define <2 x i64> @gather_v2i64_uniform_ptrs_all_active_mask(i64* %src) {
+; CHECK-LABEL: @gather_v2i64_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:    [[LOAD_SCALAR:%.*]] = load i64, i64* [[SRC:%.*]], align 8
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT1:%.*]] = insertelement <2 x i64> poison, i64 [[LOAD_SCALAR]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT2:%.*]] = shufflevector <2 x i64> [[BROADCAST_SPLATINSERT1]], <2 x i64> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    ret <2 x i64> [[BROADCAST_SPLAT2]]
+;
+  %broadcast.splatinsert = insertelement <2 x i64*> poison, i64 *%src, i32 0
+  %broadcast.splat = shufflevector <2 x i64*> %broadcast.splatinsert, <2 x i64*> poison, <2 x i32> zeroinitializer
+  %res = call <2 x i64> @llvm.masked.gather.v2i64(<2 x i64*> %broadcast.splat, i32 8, <2 x i1> <i1 true, i1 true>, <2 x i64> undef)
+  ret <2 x i64> %res
+}
+
+; Negative gather tests
+
+;; Vector of pointers is not a splat.
+define <2 x i64> @negative_gather_v2i64_non_uniform_ptrs_all_active_mask(<2 x i64*> %inVal, i64* %src ) {
+; CHECK-LABEL: @negative_gather_v2i64_non_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:    [[INSERT_VALUE:%.*]] = insertelement <2 x i64*> [[INVAL:%.*]], i64* [[SRC:%.*]], i64 1
+; CHECK-NEXT:    [[RES:%.*]] = call <2 x i64> @llvm.masked.gather.v2i64.v2p0i64(<2 x i64*> [[INSERT_VALUE]], i32 8, <2 x i1> <i1 true, i1 true>, <2 x i64> undef)
+; CHECK-NEXT:    ret <2 x i64> [[RES]]
+;
+  %insert.value = insertelement <2 x i64*> %inVal, i64 *%src, i32 1
+  %res = call <2 x i64> @llvm.masked.gather.v2i64(<2 x i64*> %insert.value, i32 8, <2 x i1> <i1 true, i1 true>, <2 x i64> undef)
+  ret <2 x i64> %res
+}
+
+;; Unknown mask value
+define <2 x i64> @negative_gather_v2i64_uniform_ptrs_no_all_active_mask(i64* %src, <2 x i1> %mask) {
+; CHECK-LABEL: @negative_gather_v2i64_uniform_ptrs_no_all_active_mask(
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <2 x i64*> poison, i64* [[SRC:%.*]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <2 x i64*> [[BROADCAST_SPLATINSERT]], <2 x i64*> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[RES:%.*]] = call <2 x i64> @llvm.masked.gather.v2i64.v2p0i64(<2 x i64*> [[BROADCAST_SPLAT]], i32 8, <2 x i1> [[MASK:%.*]], <2 x i64> undef)
+; CHECK-NEXT:    ret <2 x i64> [[RES]]
+;
+  %broadcast.splatinsert = insertelement <2 x i64*> poison, i64 *%src, i32 0
+  %broadcast.splat = shufflevector <2 x i64*> %broadcast.splatinsert, <2 x i64*> poison, <2 x i32> zeroinitializer
+  %res = call <2 x i64> @llvm.masked.gather.v2i64(<2 x i64*> %broadcast.splat, i32 8, <2 x i1> %mask, <2 x i64> undef)
+  ret <2 x i64> %res
+}
+
+; Function Attrs:
+declare <vscale x 2 x i64> @llvm.masked.gather.nxv2i64(<vscale x 2 x i64*>, i32, <vscale x 2 x i1>, <vscale x 2 x i64>)
+declare <2 x i64> @llvm.masked.gather.v2i64(<2 x i64*>, i32, <2 x i1>, <2 x i64>)