diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1287,6 +1287,7 @@
   if (auto *II = dyn_cast<IntrinsicInst>(I)) {
     switch (II->getIntrinsicID()) {
     case Intrinsic::fma:
+    case Intrinsic::vp_fma:
       return Operand == 0 || Operand == 1;
     // FIXME: Our patterns can only match vx/vf instructions when the splat
     // it on the RHS, because TableGen doesn't recognize our VP operations
diff --git a/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll b/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll
--- a/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll
@@ -3726,3 +3726,93 @@
 for.cond.cleanup:                                 ; preds = %vector.body
   ret void
 }
+
+declare <4 x float> @llvm.vp.fma.v4f32(<4 x float>, <4 x float>, <4 x float>, <4 x i1>, i32)
+
+define void @sink_splat_vp_fma(float* noalias nocapture %a, float* nocapture readonly %b, float %x, <4 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: sink_splat_vp_fma:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    fmv.w.x ft0, a2
+; CHECK-NEXT:    li a2, 1024
+; CHECK-NEXT:  .LBB65_1: # %vector.body
+; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, mu
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vle32.v v9, (a1)
+; CHECK-NEXT:    vsetvli zero, a3, e32, m1, tu, mu
+; CHECK-NEXT:    vfmadd.vf v8, ft0, v9, v0.t
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, mu
+; CHECK-NEXT:    vse32.v v8, (a0)
+; CHECK-NEXT:    addi a2, a2, -4
+; CHECK-NEXT:    addi a1, a1, 16
+; CHECK-NEXT:    addi a0, a0, 16
+; CHECK-NEXT:    bnez a2, .LBB65_1
+; CHECK-NEXT:  # %bb.2: # %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  %broadcast.splatinsert = insertelement <4 x float> poison, float %x, i32 0
+  %broadcast.splat = shufflevector <4 x float> %broadcast.splatinsert, <4 x float> poison, <4 x i32> zeroinitializer
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %entry
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %0 = getelementptr inbounds float, float* %a, i64 %index
+  %1 = bitcast float* %0 to <4 x float>*
+  %wide.load = load <4 x float>, <4 x float>* %1, align 4
+  %2 = getelementptr inbounds float, float* %b, i64 %index
+  %3 = bitcast float* %2 to <4 x float>*
+  %wide.load12 = load <4 x float>, <4 x float>* %3, align 4
+  %4 = call <4 x float> @llvm.vp.fma.v4f32(<4 x float> %wide.load, <4 x float> %broadcast.splat, <4 x float> %wide.load12, <4 x i1> %m, i32 %vl)
+  %5 = bitcast float* %0 to <4 x float>*
+  store <4 x float> %4, <4 x float>* %5, align 4
+  %index.next = add nuw i64 %index, 4
+  %6 = icmp eq i64 %index.next, 1024
+  br i1 %6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:                                 ; preds = %vector.body
+  ret void
+}
+
+define void @sink_splat_vp_fma_commute(float* noalias nocapture %a, float* nocapture readonly %b, float %x, <4 x i1> %m, i32 zeroext %vl) {
+; CHECK-LABEL: sink_splat_vp_fma_commute:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    fmv.w.x ft0, a2
+; CHECK-NEXT:    li a2, 1024
+; CHECK-NEXT:  .LBB66_1: # %vector.body
+; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, mu
+; CHECK-NEXT:    vle32.v v8, (a0)
+; CHECK-NEXT:    vle32.v v9, (a1)
+; CHECK-NEXT:    vsetvli zero, a3, e32, m1, tu, mu
+; CHECK-NEXT:    vfmadd.vf v8, ft0, v9, v0.t
+; CHECK-NEXT:    vsetivli zero, 4, e32, m1, ta, mu
+; CHECK-NEXT:    vse32.v v8, (a0)
+; CHECK-NEXT:    addi a2, a2, -4
+; CHECK-NEXT:    addi a1, a1, 16
+; CHECK-NEXT:    addi a0, a0, 16
+; CHECK-NEXT:    bnez a2, .LBB66_1
+; CHECK-NEXT:  # %bb.2: # %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  %broadcast.splatinsert = insertelement <4 x float> poison, float %x, i32 0
+  %broadcast.splat = shufflevector <4 x float> %broadcast.splatinsert, <4 x float> poison, <4 x i32> zeroinitializer
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %entry
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %0 = getelementptr inbounds float, float* %a, i64 %index
+  %1 = bitcast float* %0 to <4 x float>*
+  %wide.load = load <4 x float>, <4 x float>* %1, align 4
+  %2 = getelementptr inbounds float, float* %b, i64 %index
+  %3 = bitcast float* %2 to <4 x float>*
+  %wide.load12 = load <4 x float>, <4 x float>* %3, align 4
+  %4 = call <4 x float> @llvm.vp.fma.v4f32(<4 x float> %broadcast.splat, <4 x float> %wide.load, <4 x float> %wide.load12, <4 x i1> %m, i32 %vl)
+  %5 = bitcast float* %0 to <4 x float>*
+  store <4 x float> %4, <4 x float>* %5, align 4
+  %index.next = add nuw i64 %index, 4
+  %6 = icmp eq i64 %index.next, 1024
+  br i1 %6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:                                 ; preds = %vector.body
+  ret void
+}