diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -362,7 +362,6 @@
 // * Single constant active lane -> store
 // * Adjacent vector addresses -> masked.store
 // * Narrow store width by halfs excluding zero/undef lanes
-// * Vector splat address w/known mask -> scalar store
 // * Vector incrementing address -> vector masked store
 Instruction *InstCombinerImpl::simplifyMaskedScatter(IntrinsicInst &II) {
   auto *ConstMask = dyn_cast<Constant>(II.getArgOperand(3));
@@ -373,6 +372,27 @@
   if (ConstMask->isNullValue())
     return eraseInstFromFunction(II);
 
+  // Vector splat address w/known mask -> scalar store
+  // The mask is all ones and the destination address is a splat, so this is
+  // a uniform store and can be folded to a scalar store of the last lane.
+  if (ConstMask->isAllOnesValue())
+    if (auto *Splat = getSplatValue(II.getArgOperand(1))) {
+      Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue();
+      VectorType *WideLoad =
+          dyn_cast<VectorType>(II.getArgOperand(1)->getType());
+      ElementCount VF = WideLoad->getElementCount();
+      Constant *EC =
+          ConstantInt::get(Builder.getInt32Ty(), VF.getKnownMinValue());
+      Value *RunTimeVF = Builder.CreateVScale(EC);
+      // LastLane = RunTimeVF - 1
+      Value *LastLane = Builder.CreateSub(RunTimeVF, Builder.getInt32(1));
+      Value *Extract =
+          Builder.CreateExtractElement(II.getArgOperand(0), LastLane);
+      StoreInst *S = new StoreInst(Extract, Splat, false, Alignment);
+      S->copyMetadata(II);
+      return S;
+    }
+
   if (isa<ScalableVectorType>(ConstMask->getType()))
     return nullptr;
 
diff --git a/llvm/test/Transforms/InstCombine/vscale_masked_intrinsics.ll b/llvm/test/Transforms/InstCombine/vscale_masked_intrinsics.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/vscale_masked_intrinsics.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -instcombine -S < %s | FileCheck %s
+
+
+define void @valid_inv_store_i16(i16* noalias %dst, <vscale x 4 x i16>* noalias readonly %src) #0 {
+; CHECK-LABEL: @valid_inv_store_i16(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i16>, <vscale x 4 x i16>* [[SRC:%.*]], align 2
+; CHECK-NEXT:    [[TMP0:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[TMP1:%.*]] = shl i32 [[TMP0]], 2
+; CHECK-NEXT:    [[TMP2:%.*]] = add i32 [[TMP1]], -1
+; CHECK-NEXT:    [[TMP3:%.*]] = extractelement <vscale x 4 x i16> [[WIDE_LOAD]], i32 [[TMP2]]
+; CHECK-NEXT:    store i16 [[TMP3]], i16* [[DST:%.*]], align 2
+; CHECK-NEXT:    ret void
+;
+entry:
+  %broadcast.splatinsert = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
+  %broadcast.splat = shufflevector <vscale x 4 x i16*> %broadcast.splatinsert, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
+  %wide.load = load <vscale x 4 x i16>, <vscale x 4 x i16>* %src, align 2
+  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %wide.load, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+  ret void
+}
+
+;; The destination address is not a splat
+define void @invalid_addr_inv_store_i16(i16* noalias %dst, <vscale x 4 x i16>* noalias readonly %src) #0 {
+; CHECK-LABEL: @invalid_addr_inv_store_i16(
+; CHECK-NEXT:    [[INSERT_ELT:%.*]] = insertelement <vscale x 4 x i16*> poison, i16* [[DST:%.*]], i32 1
+; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 4 x i16*> [[INSERT_ELT]], <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i16>, <vscale x 4 x i16>* [[SRC:%.*]], align 2
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> [[WIDE_LOAD]], <vscale x 4 x i16*> [[BROADCAST_SPLAT]], i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret void
+;
+  %insert.elt = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 1
+  %broadcast.splat = shufflevector <vscale x 4 x i16*> %insert.elt, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
+  %wide.load = load <vscale x 4 x i16>, <vscale x 4 x i16>* %src, align 2
+  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %wide.load, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+  ret void
+}
+
+;; The mask is not all ones
+define void @invalid_mask_inv_store_i16(i16* noalias %dst, <vscale x 4 x i16>* noalias readonly %src) #0 {
+; CHECK-LABEL: @invalid_mask_inv_store_i16(
+; CHECK-NEXT:    [[INSERT_ELT:%.*]] = insertelement <vscale x 4 x i16*> poison, i16* [[DST:%.*]], i32 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 4 x i16*> [[INSERT_ELT]], <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i16>, <vscale x 4 x i16>* [[SRC:%.*]], align 2
+; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> [[WIDE_LOAD]], <vscale x 4 x i16*> [[BROADCAST_SPLAT]], i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 1), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret void
+;
+  %insert.elt = insertelement <vscale x 4 x i16*> poison, i16* %dst, i32 0
+  %broadcast.splat = shufflevector <vscale x 4 x i16*> %insert.elt, <vscale x 4 x i16*> poison, <vscale x 4 x i32> zeroinitializer
+  %wide.load = load <vscale x 4 x i16>, <vscale x 4 x i16>* %src, align 2
+  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %wide.load, <vscale x 4 x i16*> %broadcast.splat, i32 2, <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i32 1), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+  ret void
+}
+
+
+; Function Attrs: nofree nosync nounwind willreturn writeonly
+declare void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16>, <vscale x 4 x i16*>, i32 immarg, <vscale x 4 x i1>)
+
+
+attributes #0 = { "target-features"="+sve" }
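
Note: as a quick before/after sketch of the fold (illustrative only; the value
names %val, %splat_dst and %allones are hypothetical, mirroring the
@valid_inv_store_i16 test above), a scatter whose addresses are a splat of %dst
under an all-true mask

  call void @llvm.masked.scatter.nxv4i16.nxv4p0i16(<vscale x 4 x i16> %val, <vscale x 4 x i16*> %splat_dst, i32 2, <vscale x 4 x i1> %allones)

writes every lane to the same address, so only the highest-numbered lane's
value survives. The transform therefore reduces it to an extract of lane
(vscale * 4) - 1 followed by an ordinary scalar store:

  %vscale = call i32 @llvm.vscale.i32()
  %runtime.vf = shl i32 %vscale, 2                      ; RunTimeVF = vscale * 4
  %last.lane = add i32 %runtime.vf, -1                  ; LastLane = RunTimeVF - 1
  %last = extractelement <vscale x 4 x i16> %val, i32 %last.lane
  store i16 %last, i16* %dst, align 2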