diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -725,6 +725,22 @@
   return IC.replaceInstUsesWith(II, FMLA);
 }
 
+static bool isAllActivePredicate(Value *Pred) {
+  // Look through a convert.from.svbool(convert.to.svbool(...)) chain.
+  Value *UncastedPred;
+  if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_convert_from_svbool>(
+                      m_Intrinsic<Intrinsic::aarch64_sve_convert_to_svbool>(
+                          m_Value(UncastedPred)))))
+    // If the predicate has the same or fewer lanes than the uncasted
+    // predicate then we know the casting has no effect.
+    if (cast<ScalableVectorType>(Pred->getType())->getMinNumElements() <=
+        cast<ScalableVectorType>(UncastedPred->getType())->getMinNumElements())
+      Pred = UncastedPred;
+
+  return match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
+                         m_ConstantInt<AArch64SVEPredPattern::all>()));
+}
+
 static Optional<Instruction *> instCombineSVELD1(InstCombiner &IC,
                                                  IntrinsicInst &II,
                                                  const DataLayout &DL) {
   IRBuilder<> Builder(II.getContext());
@@ -735,8 +751,7 @@
   Type *VecTy = II.getType();
   Value *VecPtr = Builder.CreateBitCast(PtrOp, VecTy->getPointerTo());
 
-  if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
-                      m_ConstantInt<AArch64SVEPredPattern::all>()))) {
+  if (isAllActivePredicate(Pred)) {
     LoadInst *Load = Builder.CreateLoad(VecTy, VecPtr);
     return IC.replaceInstUsesWith(II, Load);
   }
@@ -758,8 +773,7 @@
   Value *VecPtr =
       Builder.CreateBitCast(PtrOp, VecOp->getType()->getPointerTo());
 
-  if (match(Pred, m_Intrinsic<Intrinsic::aarch64_sve_ptrue>(
-                      m_ConstantInt<AArch64SVEPredPattern::all>()))) {
+  if (isAllActivePredicate(Pred)) {
     Builder.CreateStore(VecOp, VecPtr);
     return IC.eraseInstFromFunction(II);
   }
diff --git a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-loadstore.ll b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-loadstore.ll
--- a/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-loadstore.ll
+++ b/llvm/test/Transforms/InstCombine/AArch64/sve-intrinsic-loadstore.ll
@@ -14,6 +14,19 @@
   ret <vscale x 4 x i32> %2
 }
 
+define <vscale x 4 x i32> @combine_ld1_casted_predicate(i32* %ptr) #0 {
+; CHECK-LABEL: @combine_ld1_casted_predicate(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i32* [[PTR:%.*]] to <vscale x 4 x i32>*
+; CHECK-NEXT:    [[TMP2:%.*]] = load <vscale x 4 x i32>, <vscale x 4 x i32>* [[TMP1]], align 16
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[TMP2]]
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+  %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %1)
+  %3 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %2)
+  %4 = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> %3, i32* %ptr)
+  ret <vscale x 4 x i32> %4
+}
+
 define <vscale x 4 x i32> @combine_ld1_masked(i32* %ptr) #0 {
 ; CHECK-LABEL: @combine_ld1_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 16)
@@ -26,6 +39,22 @@
   ret <vscale x 4 x i32> %2
 }
 
+define <vscale x 8 x i16> @combine_ld1_masked_casted_predicate(i16* %ptr) #0 {
+; CHECK-LABEL: @combine_ld1_masked_casted_predicate(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> [[TMP1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16* [[PTR:%.*]] to <vscale x 8 x i16>*
+; CHECK-NEXT:    [[TMP5:%.*]] = call <vscale x 8 x i16> @llvm.masked.load.nxv8i16.p0nxv8i16(<vscale x 8 x i16>* [[TMP4]], i32 1, <vscale x 8 x i1> [[TMP3]], <vscale x 8 x i16> zeroinitializer)
+; CHECK-NEXT:    ret <vscale x 8 x i16> [[TMP5]]
+;
+  %1 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
+  %3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
+  %4 = call <vscale x 8 x i16> @llvm.aarch64.sve.ld1.nxv8i16(<vscale x 8 x i1> %3, i16* %ptr)
+  ret <vscale x 8 x i16> %4
+}
+
 define void @combine_st1(<vscale x 4 x i32> %vec, i32* %ptr) #0 {
 ; CHECK-LABEL: @combine_st1(
 ; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i32* [[PTR:%.*]] to <vscale x 4 x i32>*
@@ -37,6 +66,19 @@
   ret void
 }
 
+define void @combine_st1_casted_predicate(<vscale x 4 x i32> %vec, i32* %ptr) #0 {
+; CHECK-LABEL: @combine_st1_casted_predicate(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i32* [[PTR:%.*]] to <vscale x 4 x i32>*
+; CHECK-NEXT:    store <vscale x 4 x i32> [[VEC:%.*]], <vscale x 4 x i32>* [[TMP1]], align 16
+; CHECK-NEXT:    ret void
+;
+  %1 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
+  %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %1)
+  %3 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %2)
+  call void @llvm.aarch64.sve.st1.nxv4i32(<vscale x 4 x i32> %vec, <vscale x 4 x i1> %3, i32* %ptr)
+  ret void
+}
+
 define void @combine_st1_masked(<vscale x 4 x i32> %vec, i32* %ptr) #0 {
 ; CHECK-LABEL: @combine_st1_masked(
 ; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 16)
@@ -49,10 +91,36 @@
   ret void
 }
 
-declare void @llvm.aarch64.sve.st1.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i1>, i32*)
-declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32)
-declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
+define void @combine_st1_masked_casted_predicate(<vscale x 8 x i16> %vec, i16* %ptr) #0 {
+; CHECK-LABEL: @combine_st1_masked_casted_predicate(
+; CHECK-NEXT:    [[TMP1:%.*]] = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+; CHECK-NEXT:    [[TMP2:%.*]] = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> [[TMP1]])
+; CHECK-NEXT:    [[TMP3:%.*]] = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> [[TMP2]])
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i16* [[PTR:%.*]] to <vscale x 8 x i16>*
+; CHECK-NEXT:    call void @llvm.masked.store.nxv8i16.p0nxv8i16(<vscale x 8 x i16> [[VEC:%.*]], <vscale x 8 x i16>* [[TMP4]], i32 1, <vscale x 8 x i1> [[TMP3]])
+; CHECK-NEXT:    ret void
+;
+  %1 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32 31)
+  %2 = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1> %1)
+  %3 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1> %2)
+  call void @llvm.aarch64.sve.st1.nxv8i16(<vscale x 8 x i16> %vec, <vscale x 8 x i1> %3, i16* %ptr)
+  ret void
+}
+
 declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv8i1(<vscale x 16 x i1>)
+
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv4i1(<vscale x 4 x i1>)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
+
 declare <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1>, i32*)
+declare <vscale x 8 x i16> @llvm.aarch64.sve.ld1.nxv8i16(<vscale x 8 x i1>, i16*)
+
+declare <vscale x 4 x i1> @llvm.aarch64.sve.ptrue.nxv4i1(i32)
+declare <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32)
+declare <vscale x 16 x i1> @llvm.aarch64.sve.ptrue.nxv16i1(i32)
+
+declare void @llvm.aarch64.sve.st1.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i1>, i32*)
+declare void @llvm.aarch64.sve.st1.nxv8i16(<vscale x 8 x i16>, <vscale x 8 x i1>, i16*)
 
 attributes #0 = { "target-features"="+sve" }
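
Why the lane-count check in isAllActivePredicate makes these folds safe, as a sketch (not part of the patch): narrowing an all-active predicate through the svbool round-trip (nxv8i1 -> svbool -> nxv4i1) cannot deactivate any lane the narrower ld1/st1 reads or writes, so the intrinsic collapses to an unpredicated load/store, while widening (nxv4i1 -> svbool -> nxv8i1) leaves lanes of the wider type inactive and therefore only canonicalises to a masked load/store, which is what combine_ld1_masked_casted_predicate checks. The standalone IR below (the function name @narrowing_roundtrip_folds is illustrative, declarations copied from the test file, expected output taken from the combine_ld1_casted_predicate CHECK lines) reproduces the narrowing case when run through opt with the InstCombine pass:

; Illustration only: an all-active nxv8i1 ptrue is round-tripped through svbool
; and narrowed to nxv4i1. Narrowing cannot clear any of the four lanes the load
; uses, so isAllActivePredicate still matches and InstCombine replaces the ld1
; intrinsic with a plain unpredicated load.
define <vscale x 4 x i32> @narrowing_roundtrip_folds(i32* %ptr) #0 {
  %p8 = tail call <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32 31)
  %pb = tail call <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1> %p8)
  %p4 = tail call <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1> %pb)
  ; expected after InstCombine (cf. combine_ld1_casted_predicate above):
  ;   %1 = bitcast i32* %ptr to <vscale x 4 x i32>*
  ;   %2 = load <vscale x 4 x i32>, <vscale x 4 x i32>* %1, align 16
  %v = call <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1> %p4, i32* %ptr)
  ret <vscale x 4 x i32> %v
}

declare <vscale x 8 x i1> @llvm.aarch64.sve.ptrue.nxv8i1(i32)
declare <vscale x 16 x i1> @llvm.aarch64.sve.convert.to.svbool.nxv8i1(<vscale x 8 x i1>)
declare <vscale x 4 x i1> @llvm.aarch64.sve.convert.from.svbool.nxv4i1(<vscale x 16 x i1>)
declare <vscale x 4 x i32> @llvm.aarch64.sve.ld1.nxv4i32(<vscale x 4 x i1>, i32*)

attributes #0 = { "target-features"="+sve" }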