diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -362,7 +362,6 @@
 // * Single constant active lane -> store
 // * Adjacent vector addresses -> masked.store
 // * Narrow store width by halfs excluding zero/undef lanes
-// * Vector splat address w/known mask -> scalar store
 // * Vector incrementing address -> vector masked store
 Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
@@ -373,6 +372,34 @@
   if (ConstMask->isNullValue())
     return eraseInstFromFunction(II);
+  // Vector splat address -> scalar store
+  if (auto *SplatPtr = getSplatValue(II.getArgOperand(1))) {
+    // scatter(splat(value), splat(ptr), non-zero-mask) -> store value, ptr
+    if (auto *SplatValue = getSplatValue(II.getArgOperand(0))) {
+      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+      StoreInst *S =
+          new StoreInst(SplatValue, SplatPtr, /*IsVolatile=*/false, Alignment);
+      S->copyMetadata(II);
+      return S;
+    }
+    // scatter(vector, splat(ptr), splat(true)) -> store extract(vector,
+    // lastlane), ptr
+    if (ConstMask->isAllOnesValue()) {
+      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+      VectorType *WideLoadTy = cast<VectorType>(II.getArgOperand(1)->getType());
+      ElementCount VF = WideLoadTy->getElementCount();
+      Constant *EC =
+          ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue());
+      Value *RunTimeVF = VF.isScalable() ? Builder.CreateVScale(EC) : EC;
+      Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1));
+      Value *Extract =
+          Builder.CreateExtractElement(II.getArgOperand(0), LastLane);
+      StoreInst *S =
+          new StoreInst(Extract, SplatPtr, /*IsVolatile=*/false, Alignment);
+      S->copyMetadata(II);
+      return S;
+    }
+  }
 
   if (isa<ScalableVectorType>(ConstMask->getType()))
     return nullptr;
 
diff --git a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
--- a/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
+++ b/llvm/test/Transforms/InstCombine/masked_intrinsics.ll
@@ -269,3 +269,110 @@
   call void @llvm.masked.scatter.v2f64.v2p0f64(<2 x double> %valvec2, <2 x double*> %ptrs, i32 8, <2 x i1> <i1 true, i1 false>)
   ret void
 }
+
+
+; Test scatters that can be simplified to scalar stores.
+
+;; Value splat (mask is not used)
+define void @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(i16* %dst, i16 %val) {
+; CHECK-LABEL: @scatter_v4i16_uniform_vals_uniform_ptrs_no_all_active_mask(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
+; CHECK-NEXT:    ret void
+;
+entry:
+  %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
+  %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
+  %broadcast.value = insertelement <4 x i16> poison, i16 %val, i32 0
+  %broadcast.splatvalue = shufflevector <4 x i16> %broadcast.value, <4 x i16> poison, <4 x i32> zeroinitializer
+  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %broadcast.splatvalue, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 false, i1 false, i1 true, i1 true>)
+  ret void
+}
+
+define void @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, i16 %val) {
+; CHECK-LABEL: @scatter_nxv4i16_uniform_vals_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    store i16 [[VAL:%.*]], i16* [[DST:%.*]], align 2
+; CHECK-NEXT:    ret void
+;
+entry:
+  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
+  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
+  %broadcast.value = insertelement <vscale x 4 x i16> poison, i16 %val, i32 0
+  %broadcast.splatvalue = shufflevector <vscale x 4 x i16> %broadcast.value, <vscale x 4 x i16> poison, <vscale x 4 x i32> zeroinitializer
+  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %broadcast.splatvalue, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> zeroinitializer, i1 true, i32 0), <vscale x 4 x i1> zeroinitializer, <vscale x 4 x i32> zeroinitializer))
+  ret void
+}
+
+;; The pointer is splat and mask is all active, but value is not a splat
+define void @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <4 x i16>* %src) {
+; CHECK-LABEL: @scatter_v4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
+; CHECK-NEXT:    [[TMP0:%.*]] = extractelement <4 x i16> [[WIDE_LOAD]], i64 3
+; CHECK-NEXT:    store i16 [[TMP0]], i16* [[DST:%.*]], align 2
+; CHECK-NEXT:    ret void
+;
+entry:
+  %broadcast.splatinsert = insertelement <4 x i16*> poison, i16* %dst, i32 0
+  %broadcast.splat = shufflevector <4 x i16*> %broadcast.splatinsert, <4 x i16*> poison, <4 x i32> zeroinitializer
+  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
+  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
+  ret void
+}
+
+define void @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(i16* %dst, <vscale x 4 x i16>* %src) {
+; CHECK-LABEL: @scatter_nxv4i16_no_uniform_vals_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i16>, <vscale x 4 x i16>* [[SRC:%.*]], align 2
+; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[TMP0]], 2
+; CHECK-NEXT:    [[TMP2:%.*]] = add i32 [[TMP1]], -1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 4 x i16> [[WIDE_LOAD]], i32 [[TMP2]]
+; CHECK-NEXT:    store i16 [[TMP3]], i16* [[DST:%.*]], align 2
+; CHECK-NEXT:    ret void
+;
+entry:
+  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
+  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
+  %wide.load = load <vscale x 4 x i16>, <vscale x 4 x i16>* %src, align 2
+  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %wide.load, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+  ret void
+}
+
+; Negative scatter tests
+
+;; Pointer is splat, but mask is not all active and value is not a splat
+define void @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(i16* %dst, <4 x i16>* %src) {
+; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_uniform_ptrs_all_inactive_mask(
+; CHECK-NEXT:    [[INSERT_ELT:%.*]] = insertelement <4 x i16*> poison, i16* [[DST:%.*]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <4 x i16*> [[INSERT_ELT]], <4 x i16*> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
+; CHECK-NEXT:    call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST_SPLAT]], i32 2, <4 x i1> <i1 false, i1 false, i1 true, i1 true>)
+; CHECK-NEXT:    ret void
+;
+  %insert.elt = insertelement <4 x i16*> poison, i16* %dst, i32 0
+  %broadcast.splat = shufflevector <4 x i16*> %insert.elt, <4 x i16*> poison, <4 x i32> zeroinitializer
+  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
+  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast.splat, i32 2, <4 x i1> <i1 false, i1 false, i1 true, i1 true>)
+  ret void
+}
+
+;; The pointer is NOT a splat
+define void @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(<4 x i16*> %inPtr, <4 x i16>* %src) {
+; CHECK-LABEL: @negative_scatter_v4i16_no_uniform_vals_no_uniform_ptrs_all_active_mask(
+; CHECK-NEXT:    [[BROADCAST:%.*]] = shufflevector <4 x i16*> [[INPTR:%.*]], <4 x i16*> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <4 x i16>, <4 x i16>* [[SRC:%.*]], align 2
+; CHECK-NEXT:    call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> [[WIDE_LOAD]], <4 x i16*> [[BROADCAST]], i32 2, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
+; CHECK-NEXT:    ret void
+;
+  %broadcast = shufflevector <4 x i16*> %inPtr, <4 x i16*> poison, <4 x i32> zeroinitializer
+  %wide.load = load <4 x i16>, <4 x i16>* %src, align 2
+  call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %wide.load, <4 x i16*> %broadcast, i32 2, <4 x i1> <i1 true, i1 true, i1 true, i1 true>)
+  ret void
+}
+
+
+; Function Attrs:
+declare void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16>, <4 x i16*>, i32 immarg, <4 x i1>)
+declare void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32 immarg, <vscale x 4 x i1>)
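
For reference, a minimal sketch (not part of the patch) of the splat-value fold added above: when both the stored value and the pointer operand are splats and the constant mask has at least one active lane, every active lane writes the same value to the same address, so the scatter is equivalent to one scalar store. The function name @example_splat_scatter and the particular mask bits below are hypothetical; per the new code, any non-zero constant mask triggers the same fold.

  declare void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16>, <4 x i16*>, i32 immarg, <4 x i1>)

  define void @example_splat_scatter(i16* %p, i16 %v) {
    ; Broadcast the pointer and the value to all four lanes.
    %pins = insertelement <4 x i16*> poison, i16* %p, i32 0
    %psplat = shufflevector <4 x i16*> %pins, <4 x i16*> poison, <4 x i32> zeroinitializer
    %vins = insertelement <4 x i16> poison, i16 %v, i32 0
    %vsplat = shufflevector <4 x i16> %vins, <4 x i16> poison, <4 x i32> zeroinitializer
    ; A single active lane is enough for the fold to apply.
    call void @llvm.masked.scatter.v4i16.v4p0i16(<4 x i16> %vsplat, <4 x i16*> %psplat, i32 2, <4 x i1> <i1 true, i1 false, i1 false, i1 false>)
    ret void
  }

Running opt -instcombine on this should reduce the scatter to "store i16 %v, i16* %p, align 2", matching what the new v4i16 tests check. The non-splat-value variant instead extracts the last lane, since scatter lanes are stored in element order and the highest active lane's value survives when all pointers alias.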