Index: llvm/lib/IR/IRBuilder.cpp
===================================================================
--- llvm/lib/IR/IRBuilder.cpp
+++ llvm/lib/IR/IRBuilder.cpp
@@ -522,14 +522,14 @@
 CallInst *IRBuilderBase::CreateMaskedGather(Value *Ptrs, Align Alignment,
                                             Value *Mask, Value *PassThru,
                                             const Twine &Name) {
-  auto *PtrsTy = cast<FixedVectorType>(Ptrs->getType());
+  auto *PtrsTy = cast<VectorType>(Ptrs->getType());
   auto *PtrTy = cast<PointerType>(PtrsTy->getElementType());
-  unsigned NumElts = PtrsTy->getNumElements();
-  auto *DataTy = FixedVectorType::get(PtrTy->getElementType(), NumElts);
+  ElementCount NumElts = PtrsTy->getElementCount();
+  auto *DataTy = VectorType::get(PtrTy->getElementType(), NumElts);
 
   if (!Mask)
     Mask = Constant::getAllOnesValue(
-        FixedVectorType::get(Type::getInt1Ty(Context), NumElts));
+        VectorType::get(Type::getInt1Ty(Context), NumElts));
 
   if (!PassThru)
     PassThru = UndefValue::get(DataTy);
@@ -552,20 +552,20 @@
 /// be accessed in memory
 CallInst *IRBuilderBase::CreateMaskedScatter(Value *Data, Value *Ptrs,
                                              Align Alignment, Value *Mask) {
-  auto *PtrsTy = cast<FixedVectorType>(Ptrs->getType());
-  auto *DataTy = cast<FixedVectorType>(Data->getType());
-  unsigned NumElts = PtrsTy->getNumElements();
+  auto *PtrsTy = cast<VectorType>(Ptrs->getType());
+  auto *DataTy = cast<VectorType>(Data->getType());
+  ElementCount NumElts = PtrsTy->getElementCount();
 
 #ifndef NDEBUG
   auto PtrTy = cast<PointerType>(PtrsTy->getElementType());
-  assert(NumElts == DataTy->getNumElements() &&
+  assert(NumElts == DataTy->getElementCount() &&
          PtrTy->getElementType() == DataTy->getElementType() &&
          "Incompatible pointer and data types");
 #endif
 
   if (!Mask)
     Mask = Constant::getAllOnesValue(
-        FixedVectorType::get(Type::getInt1Ty(Context), NumElts));
+        VectorType::get(Type::getInt1Ty(Context), NumElts));
 
   Type *OverloadedTypes[] = {DataTy, PtrsTy};
   Value *Ops[] = {Data, Ptrs, getInt32(Alignment.value()), Mask};
Index: llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
+++ llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
@@ -186,11 +186,7 @@
 
   bool getTgtMemIntrinsic(IntrinsicInst *Inst, MemIntrinsicInfo &Info);
 
-  bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
-    if (!isa<ScalableVectorType>(DataType) || !ST->hasSVE())
-      return false;
-
-    Type *Ty = cast<ScalableVectorType>(DataType)->getElementType();
+  bool isLegalScalarTypeForSVEMaskedMemOp(Type *Ty) const {
     if (Ty->isPointerTy())
       return true;
 
@@ -205,6 +201,14 @@
     return false;
   }
 
+  bool isLegalMaskedLoadStore(Type *DataType, Align Alignment) {
+    if (!isa<ScalableVectorType>(DataType) || !ST->hasSVE())
+      return false;
+
+    Type *Ty = cast<ScalableVectorType>(DataType)->getElementType();
+    return isLegalScalarTypeForSVEMaskedMemOp(Ty);
+  }
+
   bool isLegalMaskedLoad(Type *DataType, Align Alignment) {
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
@@ -213,6 +217,20 @@
     return isLegalMaskedLoadStore(DataType, Alignment);
   }
 
+  bool isLegalMaskedGatherScatter(Type *DataType) const {
+    if (isa<FixedVectorType>(DataType) || !ST->hasSVE())
+      return false;
+
+    return isLegalScalarTypeForSVEMaskedMemOp(DataType->getScalarType());
+  }
+
+  bool isLegalMaskedGather(Type *DataType, Align Alignment) const {
+    return isLegalMaskedGatherScatter(DataType);
+  }
+  bool isLegalMaskedScatter(Type *DataType, Align Alignment) const {
+    return isLegalMaskedGatherScatter(DataType);
+  }
+
   bool isLegalNTStore(Type *DataType, Align Alignment) {
     // NOTE: The logic below is mostly geared towards LV, which calls it with
     //       vectors with 2 elements. We might want to improve that, if other
Index: llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7078,7 +7078,10 @@
           continue;
 
         NumAccesses = Group->getNumMembers();
-        if (interleavedAccessCanBeWidened(&I, VF))
+        // TODO: For now we don't consider scalable vectors here as it
+        // potentially requires shuffle support for reversing vectors.
+        if (interleavedAccessCanBeWidened(&I, VF) &&
+            !VF.isScalable())
           InterleaveCost = getInterleaveGroupCost(&I, VF);
       }
 
@@ -7087,8 +7090,9 @@
               ? getGatherScatterCost(&I, VF) * NumAccesses
               : std::numeric_limits<InstructionCost::CostType>::max();
 
-      InstructionCost ScalarizationCost =
-          getMemInstScalarizationCost(&I, VF) * NumAccesses;
+      InstructionCost ScalarizationCost = !VF.isScalable() ?
+          getMemInstScalarizationCost(&I, VF) * NumAccesses
+          : std::numeric_limits<InstructionCost::CostType>::max();
 
       // Choose better solution for the current VF,
       // write down this decision and use it during vectorization.
@@ -7102,6 +7106,7 @@
         Decision = CM_GatherScatter;
         Cost = GatherScatterCost;
       } else {
+        assert(!VF.isScalable() && "We cannot yet scalarise for scalable vectors");
         Decision = CM_Scalarize;
        Cost = ScalarizationCost;
       }
@@ -7451,8 +7456,12 @@
     }
   }
 
-  assert(!VF.isScalable() && "VF is assumed to be non scalable");
-  unsigned N = isScalarAfterVectorization(I, VF) ? VF.getKnownMinValue() : 1;
+  unsigned N;
+  if (isScalarAfterVectorization(I, VF)) {
+    assert(!VF.isScalable() && "VF is assumed to be non scalable");
+    N = VF.getKnownMinValue();
+  } else
+    N = 1;
   return N * TTI.getCastInstrCost(Opcode, VectorTy, SrcVecTy, CCH, CostKind,
                                   I);
 }
Index: llvm/test/Transforms/LoopVectorize/AArch64/sve-gather-scatter.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/LoopVectorize/AArch64/sve-gather-scatter.ll
@@ -0,0 +1,145 @@
+; RUN: opt -loop-vectorize -dce -instcombine -mtriple aarch64-linux-gnu -mattr=+sve -S %s -o - | FileCheck %s
+
+define void @gather_nxv4i32(float* noalias nocapture readonly %a, i32* noalias nocapture readonly %b, float* noalias nocapture %c, i64 %n) {
+; CHECK-LABEL: @gather_nxv4i32
+; CHECK: vector.body:
+; CHECK: %[[IND:.*]] = load <vscale x 4 x i32>, <vscale x 4 x i32>*
+; CHECK: %[[EXTIND:.*]] = sext <vscale x 4 x i32> %[[IND]] to <vscale x 4 x i64>
+; CHECK: %[[PTRS:.*]] = getelementptr inbounds float, float* %a, <vscale x 4 x i64> %[[EXTIND]]
+; CHECK: %[[GLOAD:.*]] = call <vscale x 4 x float> @llvm.masked.gather.nxv4f32.nxv4p0f32(<vscale x 4 x float*> %[[PTRS]]
+; CHECK: store <vscale x 4 x float> %[[GLOAD]], <vscale x 4 x float>*
+entry:
+  %cmp10 = icmp sgt i64 %n, 0
+  br i1 %cmp10, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:                               ; preds = %entry
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.body
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit, %entry
+  ret void
+
+for.body:                                         ; preds = %for.body.preheader, %for.body
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %for.body.preheader ]
+  %arrayidx = getelementptr inbounds i32, i32* %b, i64 %indvars.iv
+  %0 = load i32, i32* %arrayidx, align 4
+  %idxprom2 = sext i32 %0 to i64
+  %arrayidx3 = getelementptr inbounds float, float* %a, i64 %idxprom2
+  %1 = load float, float* %arrayidx3, align 4
+  %arrayidx5 = getelementptr inbounds float, float* %c, i64 %indvars.iv
+  store float %1, float* %arrayidx5, align 4
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, %n
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body, !llvm.loop !0
+}
+
+define void @scatter_nxv4i32(float* noalias nocapture %a, i32* noalias nocapture readonly %b, float* noalias nocapture readonly %c, i64 %n) {
+; CHECK-LABEL: @scatter_nxv4i32
+; CHECK: vector.body:
+; CHECK: %[[VALS:.*]] = load <vscale x 4 x float>
+; CHECK: %[[IND:.*]] = load <vscale x 4 x i32>, <vscale x 4 x i32>* %7, align 4
+; CHECK: %[[EXTIND:.*]] = sext <vscale x 4 x i32> %[[IND]] to <vscale x 4 x i64>
+; CHECK: %[[PTRS:.*]] = getelementptr inbounds float, float* %a, <vscale x 4 x i64> %[[EXTIND]]
+; CHECK: call void @llvm.masked.scatter.nxv4f32.nxv4p0f32(<vscale x 4 x float> %[[VALS]], <vscale x 4 x float*> %[[PTRS]]
+entry:
+  %cmp10 = icmp sgt i64 %n, 0
+  br i1 %cmp10, label %for.body, label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.body, %entry
+  ret void
+
+for.body:                                         ; preds = %entry, %for.body
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.body ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds float, float* %c, i64 %indvars.iv
+  %0 = load float, float* %arrayidx, align 4
+  %arrayidx3 = getelementptr inbounds i32, i32* %b, i64 %indvars.iv
+  %1 = load i32, i32* %arrayidx3, align 4
+  %idxprom4 = sext i32 %1 to i64
+  %arrayidx5 = getelementptr inbounds float, float* %a, i64 %idxprom4
+  store float %0, float* %arrayidx5, align 4
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, %n
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !llvm.loop !0
+}
+
+define void @scatter_inv_nxv4i32(i32* noalias nocapture %inv, i32* noalias nocapture readonly %b, i64 %n) {
+; CHECK-LABEL: @scatter_inv_nxv4i32
+; CHECK: vector.ph:
+; CHECK: %[[INS:.*]] = insertelement <vscale x 4 x i32*> poison, i32* %inv, i32 0
+; CHECK: %[[PTRSPLAT:.*]] = shufflevector <vscale x 4 x i32*> %[[INS]], <vscale x 4 x i32*> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK: vector.body:
+; CHECK: %[[VALS:.*]] = load <vscale x 4 x i32>, <vscale x 4 x i32>* %5, align 4
+; CHECK: %[[MASK:.*]] = icmp ne <vscale x 4 x i32> %[[VALS]],
+; CHECK: call void @llvm.masked.scatter.nxv4i32.nxv4p0i32({{.*}}, <vscale x 4 x i32*> %[[PTRSPLAT]], i32 4, <vscale x 4 x i1> %[[MASK]])
+entry:
+  %cmp5 = icmp sgt i64 %n, 0
+  br i1 %cmp5, label %for.body.preheader, label %for.cond.cleanup
+
+for.body.preheader:                               ; preds = %entry
+  br label %for.body
+
+for.cond.cleanup.loopexit:                        ; preds = %for.inc
+  br label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.cond.cleanup.loopexit, %entry
+  ret void
+
+for.body:                                         ; preds = %for.body.preheader, %for.inc
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.inc ], [ 0, %for.body.preheader ]
+  %arrayidx = getelementptr inbounds i32, i32* %b, i64 %indvars.iv
+  %0 = load i32, i32* %arrayidx, align 4
+  %tobool.not = icmp eq i32 %0, 0
+  br i1 %tobool.not, label %for.inc, label %if.then
+
+if.then:                                          ; preds = %for.body
+  store i32 3, i32* %inv, align 4
+  br label %for.inc
+
+for.inc:                                          ; preds = %for.body, %if.then
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, %n
+  br i1 %exitcond.not, label %for.cond.cleanup.loopexit, label %for.body, !llvm.loop !0
+}
+
+define void @gather_inv_nxv4i32(i32* noalias nocapture %a, i32* noalias nocapture readonly %inv, i64 %n) {
+; CHECK-LABEL: @gather_inv_nxv4i32
+; CHECK: vector.ph:
+; CHECK: %[[INS:.*]] = insertelement <vscale x 4 x i32*> poison, i32* %inv, i32 0
+; CHECK: %[[PTRSPLAT:.*]] = shufflevector <vscale x 4 x i32*> %[[INS]], <vscale x 4 x i32*> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK: vector.body:
+; CHECK: %[[VALS:.*]] = load <vscale x 4 x i32>, <vscale x 4 x i32>* %5, align 4
+; CHECK: %[[MASK:.*]] = icmp sgt <vscale x 4 x i32> %[[VALS]],
+; CHECK: %{{.*}} = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0i32(<vscale x 4 x i32*> %[[PTRSPLAT]], i32 4, <vscale x 4 x i1> %[[MASK]]
+entry:
+  %cmp11 = icmp sgt i64 %n, 0
+  br i1 %cmp11, label %for.body, label %for.cond.cleanup
+
+for.cond.cleanup:                                 ; preds = %for.inc, %entry
+  ret void
+
+for.body:                                         ; preds = %entry, %for.inc
+  %indvars.iv = phi i64 [ %indvars.iv.next, %for.inc ], [ 0, %entry ]
+  %arrayidx = getelementptr inbounds i32, i32* %a, i64 %indvars.iv
+  %0 = load i32, i32* %arrayidx, align 4
+  %cmp2 = icmp sgt i32 %0, 3
+  br i1 %cmp2, label %if.then, label %for.inc
+
+if.then:                                          ; preds = %for.body
+  %1 = load i32, i32* %inv, align 4
+  store i32 %1, i32* %arrayidx, align 4
+  br label %for.inc
+
+for.inc:                                          ; preds = %for.body, %if.then
+  %indvars.iv.next = add nuw nsw i64 %indvars.iv, 1
+  %exitcond.not = icmp eq i64 %indvars.iv.next, %n
+  br i1 %exitcond.not, label %for.cond.cleanup, label %for.body, !llvm.loop !0
+}
+
+!0 = distinct !{!0, !1, !2, !3, !4, !5}
+!1 = !{!"llvm.loop.mustprogress"}
+!2 = !{!"llvm.loop.vectorize.width", i32 4}
+!3 = !{!"llvm.loop.vectorize.scalable.enable", i1 true}
+!4 = !{!"llvm.loop.interleave.count", i32 1}
+!5 = !{!"llvm.loop.vectorize.enable", i1 true}
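
Editor's note: a minimal usage sketch, not part of the patch. With the IRBuilder hunk above, CreateMaskedGather no longer requires a FixedVectorType of pointers, so it can be driven directly with a scalable vector such as <vscale x 4 x float*>. The helper name emitScalableGather and the alignment value are hypothetical; the CreateMaskedGather parameters and defaults are the ones visible in the diff.

// Illustration only: assumes an existing IRBuilder insertion point and a
// value Ptrs of type <vscale x 4 x float*>.
#include "llvm/IR/IRBuilder.h"

using namespace llvm;

static Value *emitScalableGather(IRBuilder<> &Builder, Value *Ptrs) {
  // Mask and pass-through are omitted; the builder substitutes an all-true
  // mask and an undef pass-through, as shown in the IRBuilder.cpp hunk.
  return Builder.CreateMaskedGather(Ptrs, Align(4));
}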