diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.h
@@ -120,6 +120,11 @@
   Optional<Instruction *> instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const;
+  Optional<Value *> simplifyDemandedVectorEltsIntrinsic(
+      InstCombiner &IC, IntrinsicInst &II, APInt DemandedElts, APInt &UndefElts,
+      APInt &UndefElts2, APInt &UndefElts3,
+      std::function<void(Instruction *, unsigned, APInt, APInt &)>
+          SimplifyAndSetOp) const;
 
   /// \name Scalar TTI Implementations
   /// @{
diff --git a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
--- a/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
+++ b/llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
@@ -248,6 +248,48 @@
   return None;
 }
 
+Optional<Value *> ARMTTIImpl::simplifyDemandedVectorEltsIntrinsic(
+    InstCombiner &IC, IntrinsicInst &II, APInt OrigDemandedElts,
+    APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3,
+    std::function<void(Instruction *, unsigned, APInt, APInt &)>
+        SimplifyAndSetOp) const {
+
+  // Compute the demanded elements of a narrowing MVE intrinsic. TopOpc is the
+  // index of the operand holding the top/bottom flag, which differs between
+  // intrinsics.
+  auto SimplifyNarrowInstrTopBottom = [&](unsigned TopOpc) {
+    unsigned NumElts = cast<FixedVectorType>(II.getType())->getNumElements();
+    unsigned IsTop = cast<ConstantInt>(II.getOperand(TopOpc))->getZExtValue();
+
+    // Only the odd or even lanes of operand 0 are demanded, depending on
+    // whether this is a top or a bottom instruction.
+    APInt DemandedElts =
+        APInt::getSplat(NumElts, IsTop ? APInt::getLowBitsSet(2, 1)
+                                       : APInt::getHighBitsSet(2, 1));
+    SimplifyAndSetOp(&II, 0, OrigDemandedElts & DemandedElts, UndefElts);
+    // The other lanes will be defined from the inserted elements.
+    UndefElts &= APInt::getSplat(NumElts, !IsTop ? APInt::getLowBitsSet(2, 1)
+                                                 : APInt::getHighBitsSet(2, 1));
+    return None;
+  };
+
+  switch (II.getIntrinsicID()) {
+  default:
+    break;
+  case Intrinsic::arm_mve_vcvt_narrow:
+    SimplifyNarrowInstrTopBottom(2);
+    break;
+  case Intrinsic::arm_mve_vqmovn:
+    SimplifyNarrowInstrTopBottom(4);
+    break;
+  case Intrinsic::arm_mve_vshrn:
+    SimplifyNarrowInstrTopBottom(7);
+    break;
+  }
+
+  return None;
+}
+
 InstructionCost ARMTTIImpl::getIntImmCost(const APInt &Imm, Type *Ty,
                                           TTI::TargetCostKind CostKind) {
   assert(Ty->isIntegerTy());
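Note: a minimal standalone sketch (not part of the patch) of how the top/bottom
lane masks above are built, assuming an 8-lane result and LLVM's APInt from
llvm/ADT. Lane i of the vector corresponds to bit i of the mask.

    #include "llvm/ADT/APInt.h"
    #include <cassert>
    using namespace llvm;

    int main() {
      unsigned NumElts = 8;
      // IsTop: the instruction writes the odd lanes, so only the even lanes
      // of operand 0 stay demanded. getLowBitsSet(2, 1) == 0b01, and
      // getSplat(8, ...) repeats it across the width -> 0b01010101.
      APInt DemandedEven = APInt::getSplat(NumElts, APInt::getLowBitsSet(2, 1));
      // !IsTop: the mirror image, 0b10101010.
      APInt DemandedOdd = APInt::getSplat(NumElts, APInt::getHighBitsSet(2, 1));
      assert(DemandedEven == 0x55 && DemandedOdd == 0xAA);
      return 0;
    }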
diff --git a/llvm/test/Transforms/InstCombine/ARM/mve-narrow.ll b/llvm/test/Transforms/InstCombine/ARM/mve-narrow.ll
--- a/llvm/test/Transforms/InstCombine/ARM/mve-narrow.ll
+++ b/llvm/test/Transforms/InstCombine/ARM/mve-narrow.ll
@@ -7,7 +7,7 @@
 define <8 x i16> @test_shrn_v8i16_t1(<8 x i16> %a, <8 x i16> %b, <4 x i32> %c, <4 x i32> %d) {
 ; CHECK-LABEL: @test_shrn_v8i16_t1(
-; CHECK-NEXT: [[X:%.*]] = add <8 x i16> [[A:%.*]], <i16 ...>
+; CHECK-NEXT: [[X:%.*]] = add <8 x i16> [[A:%.*]], <i16 ...>
 ; CHECK-NEXT: [[Z:%.*]] = call <8 x i16> @llvm.arm.mve.vshrn.v8i16.v4i32(<8 x i16> [[X]], <4 x i32> [[D:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 1)
 ; CHECK-NEXT: ret <8 x i16> [[Z]]
 ;
@@ -18,7 +18,7 @@
 define <8 x i16> @test_shrn_v8i16_t2(<8 x i16> %a, <8 x i16> %b, <4 x i32> %c, <4 x i32> %d) {
 ; CHECK-LABEL: @test_shrn_v8i16_t2(
-; CHECK-NEXT: [[X:%.*]] = add <8 x i16> [[A:%.*]], <i16 ...>
+; CHECK-NEXT: [[X:%.*]] = add <8 x i16> [[A:%.*]], <i16 ...>
 ; CHECK-NEXT: [[Z:%.*]] = call <8 x i16> @llvm.arm.mve.vshrn.v8i16.v4i32(<8 x i16> [[X]], <4 x i32> [[D:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 1)
 ; CHECK-NEXT: ret <8 x i16> [[Z]]
 ;
@@ -29,7 +29,7 @@
 define <8 x i16> @test_shrn_v8i16_b1(<8 x i16> %a, <8 x i16> %b, <4 x i32> %c, <4 x i32> %d) {
 ; CHECK-LABEL: @test_shrn_v8i16_b1(
-; CHECK-NEXT: [[X:%.*]] = add <8 x i16> [[A:%.*]], <i16 ...>
+; CHECK-NEXT: [[X:%.*]] = add <8 x i16> [[A:%.*]], <i16 ...>
 ; CHECK-NEXT: [[Z:%.*]] = call <8 x i16> @llvm.arm.mve.vshrn.v8i16.v4i32(<8 x i16> [[X]], <4 x i32> [[D:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 0)
 ; CHECK-NEXT: ret <8 x i16> [[Z]]
 ;
@@ -40,7 +40,7 @@
 define <8 x i16> @test_shrn_v8i16_b2(<8 x i16> %a, <8 x i16> %b, <4 x i32> %c, <4 x i32> %d) {
 ; CHECK-LABEL: @test_shrn_v8i16_b2(
-; CHECK-NEXT: [[X:%.*]] = add <8 x i16> [[A:%.*]], <i16 ...>
+; CHECK-NEXT: [[X:%.*]] = add <8 x i16> [[A:%.*]], <i16 ...>
 ; CHECK-NEXT: [[Z:%.*]] = call <8 x i16> @llvm.arm.mve.vshrn.v8i16.v4i32(<8 x i16> [[X]], <4 x i32> [[D:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 0)
 ; CHECK-NEXT: ret <8 x i16> [[Z]]
 ;
@@ -51,8 +51,7 @@
 define <8 x i16> @test_shrn_v8i16_bt(<8 x i16> %a, <8 x i16> %b, <4 x i32> %c, <4 x i32> %d) {
 ; CHECK-LABEL: @test_shrn_v8i16_bt(
-; CHECK-NEXT: [[X:%.*]] = add <8 x i16> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[Y:%.*]] = call <8 x i16> @llvm.arm.mve.vshrn.v8i16.v4i32(<8 x i16> [[X]], <4 x i32> [[C:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 0)
+; CHECK-NEXT: [[Y:%.*]] = call <8 x i16> @llvm.arm.mve.vshrn.v8i16.v4i32(<8 x i16> poison, <4 x i32> [[C:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 0)
 ; CHECK-NEXT: [[Z:%.*]] = call <8 x i16> @llvm.arm.mve.vshrn.v8i16.v4i32(<8 x i16> [[Y]], <4 x i32> [[D:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 1)
 ; CHECK-NEXT: ret <8 x i16> [[Z]]
 ;
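Note: a sketch (hypothetical 8-lane masks, not part of the patch) of why the
add disappears in the _bt and _tb tests above and below: the outer call only
passes through the lanes it does not write, the inner call only reads the
opposite lanes of its own operand 0, and the intersection of the two masks is
empty, so the add feeding the inner call is replaced by poison.

    #include "llvm/ADT/APInt.h"
    #include <cassert>
    using namespace llvm;

    int main() {
      unsigned NumElts = 8;
      // Lanes of Y demanded by the outer top call Z = vshrn(Y, D, ..., top):
      // the even lanes that Z passes through.
      APInt DemandedFromTop =
          APInt::getSplat(NumElts, APInt::getLowBitsSet(2, 1));
      // Lanes of X demanded by the inner bottom call Y = vshrn(X, C, ...,
      // bottom): the odd lanes that Y passes through.
      APInt PassThruOfBottom =
          APInt::getSplat(NumElts, APInt::getHighBitsSet(2, 1));
      // Empty intersection: no lane of X is demanded, so X is dead.
      assert((DemandedFromTop & PassThruOfBottom) == 0);
      return 0;
    }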
@@ -64,8 +63,7 @@
 define <8 x i16> @test_shrn_v8i16_tb(<8 x i16> %a, <8 x i16> %b, <4 x i32> %c, <4 x i32> %d) {
 ; CHECK-LABEL: @test_shrn_v8i16_tb(
-; CHECK-NEXT: [[X:%.*]] = add <8 x i16> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[Y:%.*]] = call <8 x i16> @llvm.arm.mve.vshrn.v8i16.v4i32(<8 x i16> [[X]], <4 x i32> [[C:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 1)
+; CHECK-NEXT: [[Y:%.*]] = call <8 x i16> @llvm.arm.mve.vshrn.v8i16.v4i32(<8 x i16> poison, <4 x i32> [[C:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 1)
 ; CHECK-NEXT: [[Z:%.*]] = call <8 x i16> @llvm.arm.mve.vshrn.v8i16.v4i32(<8 x i16> [[Y]], <4 x i32> [[D:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 0)
 ; CHECK-NEXT: ret <8 x i16> [[Z]]
 ;
@@ -105,8 +103,7 @@
 define <16 x i8> @test_shrn_v16i8_bt(<16 x i8> %a, <16 x i8> %b, <8 x i16> %c, <8 x i16> %d) {
 ; CHECK-LABEL: @test_shrn_v16i8_bt(
-; CHECK-NEXT: [[X:%.*]] = add <16 x i8> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[Y:%.*]] = call <16 x i8> @llvm.arm.mve.vshrn.v16i8.v8i16(<16 x i8> [[X]], <8 x i16> [[C:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 0)
+; CHECK-NEXT: [[Y:%.*]] = call <16 x i8> @llvm.arm.mve.vshrn.v16i8.v8i16(<16 x i8> poison, <8 x i16> [[C:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 0)
 ; CHECK-NEXT: [[Z:%.*]] = call <16 x i8> @llvm.arm.mve.vshrn.v16i8.v8i16(<16 x i8> [[Y]], <8 x i16> [[D:%.*]], i32 16, i32 0, i32 0, i32 0, i32 0, i32 1)
 ; CHECK-NEXT: ret <16 x i8> [[Z]]
 ;
@@ -171,8 +168,7 @@
 define <8 x i16> @test_qmovn_v8i16_bt(<8 x i16> %a, <8 x i16> %b, <4 x i32> %c, <4 x i32> %d) {
 ; CHECK-LABEL: @test_qmovn_v8i16_bt(
-; CHECK-NEXT: [[X:%.*]] = add <8 x i16> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[Y:%.*]] = call <8 x i16> @llvm.arm.mve.vqmovn.v8i16.v4i32(<8 x i16> [[X]], <4 x i32> [[C:%.*]], i32 0, i32 0, i32 0)
+; CHECK-NEXT: [[Y:%.*]] = call <8 x i16> @llvm.arm.mve.vqmovn.v8i16.v4i32(<8 x i16> poison, <4 x i32> [[C:%.*]], i32 0, i32 0, i32 0)
 ; CHECK-NEXT: [[Z:%.*]] = call <8 x i16> @llvm.arm.mve.vqmovn.v8i16.v4i32(<8 x i16> [[Y]], <4 x i32> [[D:%.*]], i32 0, i32 0, i32 1)
 ; CHECK-NEXT: ret <8 x i16> [[Z]]
 ;
@@ -184,8 +180,7 @@
 define <16 x i8> @test_qmovn_v16i8_bt(<16 x i8> %a, <16 x i8> %b, <8 x i16> %c, <8 x i16> %d) {
 ; CHECK-LABEL: @test_qmovn_v16i8_bt(
-; CHECK-NEXT: [[X:%.*]] = add <16 x i8> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[Y:%.*]] = call <16 x i8> @llvm.arm.mve.vqmovn.v16i8.v8i16(<16 x i8> [[X]], <8 x i16> [[C:%.*]], i32 0, i32 0, i32 0)
+; CHECK-NEXT: [[Y:%.*]] = call <16 x i8> @llvm.arm.mve.vqmovn.v16i8.v8i16(<16 x i8> poison, <8 x i16> [[C:%.*]], i32 0, i32 0, i32 0)
 ; CHECK-NEXT: [[Z:%.*]] = call <16 x i8> @llvm.arm.mve.vqmovn.v16i8.v8i16(<16 x i8> [[Y]], <8 x i16> [[D:%.*]], i32 0, i32 0, i32 1)
 ; CHECK-NEXT: ret <16 x i8> [[Z]]
 ;
@@ -223,8 +218,7 @@
 define <8 x half> @test_cvtn_v8i16_bt(<8 x half> %a, <8 x half> %b, <4 x float> %c, <4 x float> %d) {
 ; CHECK-LABEL: @test_cvtn_v8i16_bt(
-; CHECK-NEXT: [[X:%.*]] = fadd <8 x half> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT: [[Y:%.*]] = call <8 x half> @llvm.arm.mve.vcvt.narrow(<8 x half> [[X]], <4 x float> [[C:%.*]], i32 0)
+; CHECK-NEXT: [[Y:%.*]] = call <8 x half> @llvm.arm.mve.vcvt.narrow(<8 x half> poison, <4 x float> [[C:%.*]], i32 0)
 ; CHECK-NEXT: [[Z:%.*]] = call <8 x half> @llvm.arm.mve.vcvt.narrow(<8 x half> [[Y]], <4 x float> [[D:%.*]], i32 1)
 ; CHECK-NEXT: ret <8 x half> [[Z]]
 ;
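Note: one more hypothetical example (not from the patch) of how
OrigDemandedElts interacts with the static lane mask: if a user of a top
narrowing call demands only lane 3, a lane written by the call itself, then
nothing at all is demanded from the pass-through operand 0.

    #include "llvm/ADT/APInt.h"
    #include <cassert>
    using namespace llvm;

    int main() {
      unsigned NumElts = 8;
      APInt OrigDemandedElts(NumElts, 1u << 3); // only lane 3 demanded
      // For a top instruction only the even lanes of operand 0 are demanded.
      APInt DemandedElts = APInt::getSplat(NumElts, APInt::getLowBitsSet(2, 1));
      // Lane 3 is odd, so the intersection is empty.
      assert((OrigDemandedElts & DemandedElts) == 0);
      return 0;
    }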