diff --git a/llvm/lib/CodeGen/ExpandVectorPredication.cpp b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
--- a/llvm/lib/CodeGen/ExpandVectorPredication.cpp
+++ b/llvm/lib/CodeGen/ExpandVectorPredication.cpp
@@ -179,6 +179,10 @@
   Value *expandPredicationInMemoryIntrinsic(IRBuilder<> &Builder,
                                             VPIntrinsic &VPI);
 
+  /// \brief Lower this VP comparison to an unpredicated comparison.
+  Value *expandPredicationInComparison(IRBuilder<> &Builder,
+                                       VPCmpIntrinsic &PI);
+
   /// \brief Query TTI and expand the vector predication in \p P accordingly.
   Value *expandPredication(VPIntrinsic &PI);
 
@@ -462,6 +466,24 @@
   return NewMemoryInst;
 }
 
+Value *CachingVPExpander::expandPredicationInComparison(IRBuilder<> &Builder,
+                                                        VPCmpIntrinsic &VPI) {
+  assert((maySpeculateLanes(VPI) || VPI.canIgnoreVectorLengthParam()) &&
+         "Implicitly dropping %evl in non-speculatable operator!");
+
+  assert(*VPI.getFunctionalOpcode() == Instruction::ICmp ||
+         *VPI.getFunctionalOpcode() == Instruction::FCmp);
+
+  Value *Op0 = VPI.getOperand(0);
+  Value *Op1 = VPI.getOperand(1);
+  auto Pred = VPI.getPredicate();
+
+  auto *NewCmp = Builder.CreateCmp(Pred, Op0, Op1);
+
+  replaceOperation(*NewCmp, VPI);
+  return NewCmp;
+}
+
 void CachingVPExpander::discardEVLParameter(VPIntrinsic &VPI) {
   LLVM_DEBUG(dbgs() << "Discard EVL parameter in " << VPI << "\n");
 
@@ -538,6 +560,9 @@
   if (auto *VPRI = dyn_cast<VPReductionIntrinsic>(&VPI))
     return expandPredicationInReduction(Builder, *VPRI);
 
+  if (auto *VPCmp = dyn_cast<VPCmpIntrinsic>(&VPI))
+    return expandPredicationInComparison(Builder, *VPCmp);
+
   switch (VPI.getIntrinsicID()) {
   default:
     break;
diff --git a/llvm/test/CodeGen/Generic/expand-vp.ll b/llvm/test/CodeGen/Generic/expand-vp.ll
--- a/llvm/test/CodeGen/Generic/expand-vp.ll
+++ b/llvm/test/CodeGen/Generic/expand-vp.ll
@@ -39,6 +39,9 @@
 declare float @llvm.vp.reduce.fmax.v4f32(float, <4 x float>, <4 x i1>, i32)
 declare float @llvm.vp.reduce.fadd.v4f32(float, <4 x float>, <4 x i1>, i32)
 declare float @llvm.vp.reduce.fmul.v4f32(float, <4 x float>, <4 x i1>, i32)
+; Comparisons
+declare <8 x i1> @llvm.vp.icmp.v8i32(<8 x i32>, <8 x i32>, metadata, <8 x i1>, i32)
+declare <8 x i1> @llvm.vp.fcmp.v8f32(<8 x float>, <8 x float>, metadata, <8 x i1>, i32)
 
 ; Fixed vector test function.
 define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) {
@@ -121,6 +124,14 @@
   ret void
 }
 
+define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x float> %f1, <8 x i1> %m, i32 %n) {
+  %r0 = call <8 x i1> @llvm.vp.icmp.v8i32(<8 x i32> %i0, <8 x i32> %i1, metadata !"eq", <8 x i1> %m, i32 %n)
+  %r1 = call <8 x i1> @llvm.vp.icmp.v8i32(<8 x i32> %i0, <8 x i32> %i1, metadata !"slt", <8 x i1> %m, i32 %n)
+  %r2 = call <8 x i1> @llvm.vp.fcmp.v8f32(<8 x float> %f0, <8 x float> %f1, metadata !"oeq", <8 x i1> %m, i32 %n)
+  %r3 = call <8 x i1> @llvm.vp.fcmp.v8f32(<8 x float> %f0, <8 x float> %f1, metadata !"ult", <8 x i1> %m, i32 %n)
+  ret void
+}
+
 ; All VP intrinsics have to be lowered into non-VP ops
 ; Convert %evl into %mask for non-speculatable VP intrinsics and emit the
 ; instruction+select idiom with a non-VP SIMD instruction.
@@ -233,6 +244,15 @@
 ; ALL-CONVERT-NEXT: %{{.+}} = call reassoc float @llvm.vector.reduce.fmul.v4f32(float %f, <4 x float> [[FMUL]])
 ; ALL-CONVERT-NEXT: ret void
 
+; Check that comparisons use the correct condition codes
+; ALL-CONVERT: define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x float> %f1, <8 x i1> %m, i32 %n) {
+; ALL-CONVERT-NEXT: %{{.+}} = icmp eq <8 x i32> %i0, %i1
+; ALL-CONVERT-NEXT: %{{.+}} = icmp slt <8 x i32> %i0, %i1
+; ALL-CONVERT-NEXT: %{{.+}} = fcmp oeq <8 x float> %f0, %f1
+; ALL-CONVERT-NEXT: %{{.+}} = fcmp ult <8 x float> %f0, %f1
+; ALL-CONVERT-NEXT: ret void
+
+
 ; All legal - don't transform anything.
 
 ; LEGAL_LEGAL: define void @test_vp_int_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x i32> %i2, <8 x i32> %f3, <8 x i1> %m, i32 %n) {
@@ -292,6 +312,13 @@
 ; LEGAL_LEGAL-NEXT: %r9 = call reassoc float @llvm.vp.reduce.fmul.v4f32(float %f, <4 x float> %vf, <4 x i1> %m, i32 %n)
 ; LEGAL_LEGAL-NEXT: ret void
 
+; LEGAL_LEGAL: define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x float> %f1, <8 x i1> %m, i32 %n) {
+; LEGAL_LEGAL-NEXT: %r0 = call <8 x i1> @llvm.vp.icmp.v8i32(<8 x i32> %i0, <8 x i32> %i1, metadata !"eq", <8 x i1> %m, i32 %n)
+; LEGAL_LEGAL-NEXT: %r1 = call <8 x i1> @llvm.vp.icmp.v8i32(<8 x i32> %i0, <8 x i32> %i1, metadata !"slt", <8 x i1> %m, i32 %n)
+; LEGAL_LEGAL-NEXT: %r2 = call <8 x i1> @llvm.vp.fcmp.v8f32(<8 x float> %f0, <8 x float> %f1, metadata !"oeq", <8 x i1> %m, i32 %n)
+; LEGAL_LEGAL-NEXT: %r3 = call <8 x i1> @llvm.vp.fcmp.v8f32(<8 x float> %f0, <8 x float> %f1, metadata !"ult", <8 x i1> %m, i32 %n)
+; LEGAL_LEGAL-NEXT: ret void
+
 ; Drop %evl where possible else fold %evl into %mask (%evl Discard, %mask Legal)
 ;
 ; There is no caching yet in the ExpandVectorPredication pass and the %evl
@@ -372,6 +399,12 @@
 ; DISCARD_LEGAL-NOT: %r9 = call reassoc float @llvm.vp.reduce.fmul.v4f32(float %f, <4 x float> %vf, <4 x i1> %m, i32 4)
 ; DISCARD_LEGAL: ret void
 
+; DISCARD_LEGAL: define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x float> %f1, <8 x i1> %m, i32 %n) {
+; DISCARD_LEGAL-NEXT: %r0 = call <8 x i1> @llvm.vp.icmp.v8i32(<8 x i32> %i0, <8 x i32> %i1, metadata !"eq", <8 x i1> %m, i32 8)
+; DISCARD_LEGAL-NEXT: %r1 = call <8 x i1> @llvm.vp.icmp.v8i32(<8 x i32> %i0, <8 x i32> %i1, metadata !"slt", <8 x i1> %m, i32 8)
+; DISCARD_LEGAL-NEXT: %r2 = call <8 x i1> @llvm.vp.fcmp.v8f32(<8 x float> %f0, <8 x float> %f1, metadata !"oeq", <8 x i1> %m, i32 8)
+; DISCARD_LEGAL-NEXT: %r3 = call <8 x i1> @llvm.vp.fcmp.v8f32(<8 x float> %f0, <8 x float> %f1, metadata !"ult", <8 x i1> %m, i32 8)
+
 ; Convert %evl into %mask everywhere (%evl Convert, %mask Legal)
 ;
 ; For the same reasons as in the (%evl Discard, %mask Legal) case only check that..
@@ -441,3 +474,15 @@
 ; CONVERT_LEGAL-NOT: %{{.+}} = call float @llvm.vp.reduce.fmul.v4f32(float %f, <4 x float> %vf, <4 x i1> %m, i32 4)
 ; CONVERT_LEGAL-NOT: %{{.+}} = call reassoc float @llvm.vp.reduce.fmul.v4f32(float %f, <4 x float> %vf, <4 x i1> %m, i32 4)
 ; CONVERT_LEGAL: ret void
+
+; CONVERT_LEGAL: define void @test_vp_cmp_v8(<8 x i32> %i0, <8 x i32> %i1, <8 x float> %f0, <8 x float> %f1, <8 x i1> %m, i32 %n) {
+; CONVERT_LEGAL-NEXT: [[NINS:%.+]] = insertelement <8 x i32> poison, i32 %n, i32 0
+; CONVERT_LEGAL-NEXT: [[NSPLAT:%.+]] = shufflevector <8 x i32> [[NINS]], <8 x i32> poison, <8 x i32> zeroinitializer
+; CONVERT_LEGAL-NEXT: [[EVLM:%.+]] = icmp ult <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, [[NSPLAT]]
+; CONVERT_LEGAL-NEXT: [[NEWM:%.+]] = and <8 x i1> [[EVLM]], %m
+; CONVERT_LEGAL-NEXT: %{{.+}} = call <8 x i1> @llvm.vp.icmp.v8i32(<8 x i32> %i0, <8 x i32> %i1, metadata !"eq", <8 x i1> [[NEWM]], i32 8)
+; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i1> @llvm.vp.icmp.v8i32(<8 x i32> %i0, <8 x i32> %i1, metadata !"eq", <8 x i1> %m, i32 %n)
+; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i1> @llvm.vp.icmp.v8i32(<8 x i32> %i0, <8 x i32> %i1, metadata !"slt", <8 x i1> %m, i32 %n)
+; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i1> @llvm.vp.fcmp.v8f32(<8 x float> %f0, <8 x float> %f1, metadata !"oeq", <8 x i1> %m, i32 %n)
+; CONVERT_LEGAL-NOT: %{{.+}} = call <8 x i1> @llvm.vp.fcmp.v8f32(<8 x float> %f0, <8 x float> %f1, metadata !"ult", <8 x i1> %m, i32 %n)
+; CONVERT_LEGAL: ret void
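
Note on the lowering: integer and FP comparisons are always safe to speculate (see the maySpeculateLanes assertion in the patch), so the full-conversion strategy (ALL-CONVERT above) can drop both the %mask and %evl arguments and emit the plain compare instruction; results on disabled lanes are unspecified, so recomputing them is harmless as long as consumers re-apply the mask. A minimal before/after sketch of that path (illustrative only, not FileCheck-verified; values reuse the names from the tests):

  ; Before: predicated comparison with mask %m and explicit vector length %n.
  %r0 = call <8 x i1> @llvm.vp.icmp.v8i32(<8 x i32> %i0, <8 x i32> %i1,
                                          metadata !"eq", <8 x i1> %m, i32 %n)

  ; After expansion: the condition-code metadata becomes the icmp predicate
  ; (via VPCmpIntrinsic::getPredicate and Builder.CreateCmp); %m and %n are
  ; dropped because the comparison is speculatable.
  %r0 = icmp eq <8 x i32> %i0, %i1

When %evl cannot simply be discarded, the CONVERT_LEGAL checks show the alternative: %evl is folded into the mask by comparing the lane-index vector <0, 1, ..., 7> against a splat of %n and ANDing the result with %m, leaving a plain masked intrinsic with the full vector length.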