Index: llvm/lib/CodeGen/ExpandVectorPredication.cpp
===================================================================
--- llvm/lib/CodeGen/ExpandVectorPredication.cpp
+++ llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -69,9 +69,17 @@
 #undef VPINTERNAL_VPLEGAL_CASES
 
+// VP strided load/store -> VP gather/scatter transformation
+static cl::opt<bool> StridedOnlyTransformation(
+    "expandvp-strided-only", cl::init(false), cl::Hidden,
+    cl::desc("If true, ignore TargetTransformInfo and always expand "
+             "experimental.vp.strided.load/store intrinsics to "
+             "vp.gather/scatter ones (used in testing)."));
+
 // Whether any override options are set.
 static bool anyExpandVPOverridesSet() {
-  return !EVLTransformOverride.empty() || !MaskTransformOverride.empty();
+  return !EVLTransformOverride.empty() || !MaskTransformOverride.empty() ||
+         StridedOnlyTransformation;
 }
 
 #define DEBUG_TYPE "expandvp"
@@ -182,6 +190,11 @@
   Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                             VPIntrinsic &VPI);
 
+  /// \brief Lower this VP strided load/store either to a VP gather/scatter,
+  /// if legal for the target, or to an unpredicated masked gather/scatter.
+  Value *expandPredicationInStridedLoadStore(IRBuilder<> &Builder,
+                                             VPIntrinsic &PI);
+
   /// \brief Query TTI and expand the vector predication in \p P accordingly.
   Value *expandPredication(VPIntrinsic &PI);
 
@@ -551,6 +564,9 @@
   case Intrinsic::vp_gather:
   case Intrinsic::vp_scatter:
     return expandPredicationInMemoryIntrinsic(Builder, VPI);
+  case Intrinsic::experimental_vp_strided_load:
+  case Intrinsic::experimental_vp_strided_store:
+    return expandPredicationInStridedLoadStore(Builder, VPI);
   }
 
   return &VPI;
@@ -598,11 +614,91 @@
   // Overrides set - we are in testing, the following does not need to be
   // efficient.
+
+  // If StridedOnlyTransformation is true, EVLTransformOverride and
+  // MaskTransformOverride are ignored.
+  if (StridedOnlyTransformation) {
+    VPStrat.EVLParamStrategy = VPLegalization::Legal;
+    switch (VPI.getIntrinsicID()) {
+    default:
+      VPStrat.OpStrategy = VPLegalization::Legal;
+      break;
+    case Intrinsic::experimental_vp_strided_load:
+    case Intrinsic::experimental_vp_strided_store:
+      VPStrat.OpStrategy = VPLegalization::Convert;
+      break;
+    }
+    return VPStrat;
+  }
+
   VPStrat.EVLParamStrategy = parseOverrideOption(EVLTransformOverride);
   VPStrat.OpStrategy = parseOverrideOption(MaskTransformOverride);
   return VPStrat;
 }
 
+Value *
+CachingVPExpander::expandPredicationInStridedLoadStore(IRBuilder<> &Builder,
+                                                       VPIntrinsic &PI) {
+  auto CreateGatherScatterPointer = [&](VectorType *VT,
+                                        Value *Stride) -> Value * {
+    Value *MemPtr = PI.getMemoryPointerParam();
+    ElementCount EC = VT->getElementCount();
+    // Bitcast MemPtr to a pointer to i8.
+    Value *Cast = Builder.CreateBitCast(MemPtr, Builder.getInt8PtrTy());
+    // Create a vector of pointers %addrs in the form:
+    // %addrs = <%addr, %addr + %stride, %addr + 2 * %stride, ...>
+    Value *Splat = Builder.CreateVectorSplat(EC, Stride);
+    Value *StepVector =
+        Builder.CreateStepVector(VectorType::get(Builder.getInt64Ty(), EC));
+    Value *Indexes =
+        Builder.CreateBinOp(Instruction::BinaryOps::Mul, StepVector, Splat);
+    Value *GEP = Builder.CreateGEP(Builder.getInt8Ty(), Cast, Indexes);
+    // Bitcast the generated vector of i8 pointers back to pointers to the
+    // original element type.
+    VectorType *DstType =
+        VectorType::get(VT->getElementType()->getPointerTo(), EC);
+    return Builder.CreateBitCast(GEP, DstType);
+  };
+
+  Intrinsic::ID IntrNo = Intrinsic::not_intrinsic;
+  SmallVector<Type *, 2> Types;
+  SmallVector<Value *, 4> Args;
+  switch (PI.getIntrinsicID()) {
+  default:
+    llvm_unreachable("Not a vp.strided.* intrinsic");
+  case Intrinsic::experimental_vp_strided_load: {
+    IntrNo = Intrinsic::vp_gather;
+    VectorType *VT = cast<VectorType>(PI.getType());
+    Value *VecOfPtrs = CreateGatherScatterPointer(VT, PI.getOperand(1));
+    Types = {VT, VecOfPtrs->getType()};
+    Args = {VecOfPtrs, PI.getMaskParam(), PI.getVectorLengthParam()};
+    break;
+  }
+  case Intrinsic::experimental_vp_strided_store: {
+    IntrNo = Intrinsic::vp_scatter;
+    VectorType *VT = cast<VectorType>(PI.getMemoryDataParam()->getType());
+    Value *VecOfPtrs = CreateGatherScatterPointer(VT, PI.getOperand(2));
+    Types = {VT, VecOfPtrs->getType()};
+    Args = {PI.getMemoryDataParam(), VecOfPtrs, PI.getMaskParam(),
+            PI.getVectorLengthParam()};
+    break;
+  }
+  }
+
+  auto *Intr = Builder.CreateIntrinsic(IntrNo, Types, Args);
+  replaceOperation(*Intr, PI);
+
+  VPIntrinsic *VPIntr = cast<VPIntrinsic>(Intr);
+  auto VPStrat = getVPLegalizationStrategy(*VPIntr);
+  sanitizeStrategy(*Intr, VPStrat);
+  if (!VPStrat.shouldDoNothing()) {
+    Builder.SetInsertPoint(VPIntr);
+    return expandPredicationInMemoryIntrinsic(Builder, *VPIntr);
+  }
+
+  return Intr;
+}
+
 /// \brief Expand llvm.vp.* intrinsics as requested by \p TTI.
 bool CachingVPExpander::expandVectorPredication() {
   SmallVector<TransformJob, 16> Worklist;
@@ -639,7 +735,7 @@
   }
   Job.Strategy.EVLParamStrategy = VPLegalization::Legal;
 
-  // Replace with a non-predicated operation.
+  // Replace with a non-predicated operation or another legal VP operation.
   switch (Job.Strategy.OpStrategy) {
   case VPLegalization::Legal:
     break;
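
For reference, here is a minimal LLVM IR sketch of the address computation and gather that expandPredicationInStridedLoadStore builds for a fixed-width strided load. The function and value names are illustrative only, and for simplicity the EVL is passed through unchanged here, whereas the pass first folds it into the mask (the icmp/and in the CHECK lines of the test below); the test checks the exact emitted sequence.

define <4 x float> @strided_load_as_gather(float* %ptr, i64 %stride, <4 x i1> %mask, i32 zeroext %evl) {
  ; Splat the scalar stride and scale it by the per-lane step vector.
  %stride.ins = insertelement <4 x i64> poison, i64 %stride, i32 0
  %stride.splat = shufflevector <4 x i64> %stride.ins, <4 x i64> poison, <4 x i32> zeroinitializer
  %offsets = mul <4 x i64> <i64 0, i64 1, i64 2, i64 3>, %stride.splat
  ; Compute one byte address per lane, then cast back to element pointers.
  %base = bitcast float* %ptr to i8*
  %addrs = getelementptr i8, i8* %base, <4 x i64> %offsets
  %ptrs = bitcast <4 x i8*> %addrs to <4 x float*>
  ; The strided load becomes a vp.gather over the per-lane addresses.
  %v = call <4 x float> @llvm.vp.gather.v4f32.v4p0f32(<4 x float*> %ptrs, <4 x i1> %mask, i32 %evl)
  ret <4 x float> %v
}

declare <4 x float> @llvm.vp.gather.v4f32.v4p0f32(<4 x float*>, <4 x i1>, i32)
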
Index: llvm/test/CodeGen/Generic/expand-vp-strided-accesses.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/Generic/expand-vp-strided-accesses.ll
@@ -0,0 +1,92 @@
+; RUN: opt --expandvp --expandvp-strided-only -S < %s | FileCheck %s --check-prefixes=CHECK,VP_TO_VP
+; RUN: opt --expandvp -S < %s | FileCheck %s --check-prefixes=CHECK,VP_TO_NOT_VP
+
+; Fixed vectors
+define <4 x float> @strided_load_v4f32(float* %ptr, i64 %stride, <4 x i1> %mask, i32 zeroext %evl) {
+; CHECK-LABEL: @strided_load_v4f32(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[EVL:%.*]], i32 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, [[DOTSPLAT]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and <4 x i1> [[TMP1]], [[MASK:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast float* [[PTR:%.*]] to i8*
+; CHECK-NEXT:    [[DOTSPLATINSERT1:%.*]] = insertelement <4 x i64> poison, i64 [[STRIDE:%.*]], i32 0
+; CHECK-NEXT:    [[DOTSPLAT2:%.*]] = shufflevector <4 x i64> [[DOTSPLATINSERT1]], <4 x i64> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = mul <4 x i64> <i64 0, i64 1, i64 2, i64 3>, [[DOTSPLAT2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr i8, i8* [[TMP3]], <4 x i64> [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <4 x i8*> [[TMP5]] to <4 x float*>
+;
+; VP_TO_VP:      [[TMP7:%.*]] = call <4 x float> @llvm.vp.gather.v4f32.v4p0f32(<4 x float*> [[TMP6]], <4 x i1> [[TMP2]], i32 4)
+; VP_TO_NOT_VP:  [[TMP7:%.*]] = call <4 x float> @llvm.masked.gather.v4f32.v4p0f32(<4 x float*> [[TMP6]], i32 4, <4 x i1> [[TMP2]], <4 x float> undef)
+;
+  %v = call <4 x float> @llvm.experimental.vp.strided.load.v4f32.p0f32.i64(float* %ptr, i64 %stride, <4 x i1> %mask, i32 %evl)
+  ret <4 x float> %v
+}
+
+define void @strided_store_v4f32(<4 x float> %v, float* %ptr, i64 %stride, <4 x i1> %mask, i32 zeroext %evl) {
+; CHECK-LABEL: @strided_store_v4f32(
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <4 x i32> poison, i32 [[EVL:%.*]], i32 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <4 x i32> [[DOTSPLATINSERT]], <4 x i32> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, [[DOTSPLAT]]
+; CHECK-NEXT:    [[TMP2:%.*]] = and <4 x i1> [[TMP1]], [[MASK:%.*]]
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast float* [[PTR:%.*]] to i8*
+; CHECK-NEXT:    [[DOTSPLATINSERT1:%.*]] = insertelement <4 x i64> poison, i64 [[STRIDE:%.*]], i32 0
+; CHECK-NEXT:    [[DOTSPLAT2:%.*]] = shufflevector <4 x i64> [[DOTSPLATINSERT1]], <4 x i64> poison, <4 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = mul <4 x i64> <i64 0, i64 1, i64 2, i64 3>, [[DOTSPLAT2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr i8, i8* [[TMP3]], <4 x i64> [[TMP4]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast <4 x i8*> [[TMP5]] to <4 x float*>
+;
+; VP_TO_VP:      call void @llvm.vp.scatter.v4f32.v4p0f32(<4 x float> [[V:%.*]], <4 x float*> [[TMP6]], <4 x i1> [[TMP2]], i32 4)
+; VP_TO_NOT_VP:  call void @llvm.masked.scatter.v4f32.v4p0f32(<4 x float> [[V:%.*]], <4 x float*> [[TMP6]], i32 4, <4 x i1> [[TMP2]])
+;
+  call void @llvm.experimental.vp.strided.store.v4f32.p0f32.i64(<4 x float> %v, float* %ptr, i64 %stride, <4 x i1> %mask, i32 %evl)
+  ret void
+}
+
+; Scalable vectors
+define <vscale x 2 x float> @strided_load_nxv2f32(float* %ptr, i64 %stride, <vscale x 2 x i1> %mask, i32 zeroext %evl) {
+; CHECK-LABEL: @strided_load_nxv2f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i32(i32 0, i32 [[EVL:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = and <vscale x 2 x i1> [[TMP1]], [[MASK:%.*]]
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[SCALABLE_SIZE:%.*]] = mul nuw i32 [[VSCALE]], 2
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast float* [[PTR:%.*]] to i8*
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[STRIDE:%.*]], i32 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = mul <vscale x 2 x i64> [[TMP4]], [[DOTSPLAT]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, i8* [[TMP3]], <vscale x 2 x i64> [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <vscale x 2 x i8*> [[TMP6]] to <vscale x 2 x float*>
+;
+; VP_TO_VP:      [[TMP8:%.*]] = call <vscale x 2 x float> @llvm.vp.gather.nxv2f32.nxv2p0f32(<vscale x 2 x float*> [[TMP7]], <vscale x 2 x i1> [[TMP2]], i32 [[SCALABLE_SIZE]])
+; VP_TO_NOT_VP:  [[TMP8:%.*]] = call <vscale x 2 x float> @llvm.masked.gather.nxv2f32.nxv2p0f32(<vscale x 2 x float*> [[TMP7]], i32 4, <vscale x 2 x i1> [[TMP2]], <vscale x 2 x float> undef)
+;
+  %v = call <vscale x 2 x float> @llvm.experimental.vp.strided.load.nxv2f32.p0f32.i64(float* %ptr, i64 %stride, <vscale x 2 x i1> %mask, i32 %evl)
+  ret <vscale x 2 x float> %v
+}
+
+define void @strided_store_nxv2f32(<vscale x 2 x float> %v, float* %ptr, i64 %stride, <vscale x 2 x i1> %mask, i32 zeroext %evl) {
+; CHECK-LABEL: @strided_store_nxv2f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call <vscale x 2 x i1> @llvm.get.active.lane.mask.nxv2i1.i32(i32 0, i32 [[EVL:%.*]])
+; CHECK-NEXT:    [[TMP2:%.*]] = and <vscale x 2 x i1> [[TMP1]], [[MASK:%.*]]
+; CHECK-NEXT:    [[VSCALE:%.*]] = call i32 @llvm.vscale.i32()
+; CHECK-NEXT:    [[SCALABLE_SIZE:%.*]] = mul nuw i32 [[VSCALE]], 2
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast float* [[PTR:%.*]] to i8*
+; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[STRIDE:%.*]], i32 0
+; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
+; CHECK-NEXT:    [[TMP4:%.*]] = call <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
+; CHECK-NEXT:    [[TMP5:%.*]] = mul <vscale x 2 x i64> [[TMP4]], [[DOTSPLAT]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, i8* [[TMP3]], <vscale x 2 x i64> [[TMP5]]
+; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <vscale x 2 x i8*> [[TMP6]] to <vscale x 2 x float*>
+;
+; VP_TO_VP:      call void @llvm.vp.scatter.nxv2f32.nxv2p0f32(<vscale x 2 x float> [[V:%.*]], <vscale x 2 x float*> [[TMP7]], <vscale x 2 x i1> [[TMP2]], i32 [[SCALABLE_SIZE]])
+; VP_TO_NOT_VP:  call void @llvm.masked.scatter.nxv2f32.nxv2p0f32(<vscale x 2 x float> [[V:%.*]], <vscale x 2 x float*> [[TMP7]], i32 4, <vscale x 2 x i1> [[TMP2]])
+;
+  call void @llvm.experimental.vp.strided.store.nxv2f32.p0f32.i64(<vscale x 2 x float> %v, float* %ptr, i64 %stride, <vscale x 2 x i1> %mask, i32 %evl)
+  ret void
+}
+
+declare <4 x float> @llvm.experimental.vp.strided.load.v4f32.p0f32.i64(float*, i64, <4 x i1>, i32)
+declare void @llvm.experimental.vp.strided.store.v4f32.p0f32.i64(<4 x float>, float*, i64, <4 x i1>, i32)
+
+declare <vscale x 2 x float> @llvm.experimental.vp.strided.load.nxv2f32.p0f32.i64(float*, i64, <vscale x 2 x i1>, i32)
+declare void @llvm.experimental.vp.strided.store.nxv2f32.p0f32.i64(<vscale x 2 x float>, float*, i64, <vscale x 2 x i1>, i32)
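
With the second RUN line (plain --expandvp), the vp.gather/vp.scatter created by the strided expansion is itself reported as needing conversion, so it is immediately lowered to its unpredicated masked counterpart; by that point the EVL has already been folded into the mask (the icmp/and in the common CHECK lines). A rough LLVM IR sketch of the final store form, with illustrative names and the alignment of 4 that appears in the checks above:

define void @scatter_with_evl_folded(<4 x float> %v, <4 x float*> %ptrs, <4 x i1> %mask, i32 zeroext %evl) {
  ; The EVL is folded into the mask: lanes with index >= %evl are disabled.
  %evl.ins = insertelement <4 x i32> poison, i32 %evl, i32 0
  %evl.splat = shufflevector <4 x i32> %evl.ins, <4 x i32> poison, <4 x i32> zeroinitializer
  %evl.mask = icmp ult <4 x i32> <i32 0, i32 1, i32 2, i32 3>, %evl.splat
  %new.mask = and <4 x i1> %evl.mask, %mask
  ; The vp.scatter is then replaced by an unpredicated masked.scatter.
  call void @llvm.masked.scatter.v4f32.v4p0f32(<4 x float> %v, <4 x float*> %ptrs, i32 4, <4 x i1> %new.mask)
  ret void
}

declare void @llvm.masked.scatter.v4f32.v4p0f32(<4 x float>, <4 x float*>, i32, <4 x i1>)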