diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1232,7 +1232,7 @@
   if (!I->getType()->isVectorTy() || !Subtarget.hasVInstructions())
     return false;
 
-  auto IsSinker = [&](Instruction *I, int Operand) {
+  auto IsSinker = [](Instruction *I, int Operand) {
     switch (I->getOpcode()) {
     case Instruction::Add:
     case Instruction::Sub:
@@ -1295,30 +1295,93 @@
     }
   };
 
+  auto IsMaskOperand = [](Instruction *I, int Operand) {
+    switch (I->getOpcode()) {
+    case Instruction::Call:
+      if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+        switch (II->getIntrinsicID()) {
+        case Intrinsic::vp_fneg:
+          return Operand == 1;
+        // FIXME: Our patterns can only match vx/vf instructions when the splat
+        // is on the RHS, because TableGen doesn't recognize our VP operations
+        // as commutative.
+        case Intrinsic::vp_add:
+        case Intrinsic::vp_sub:
+        case Intrinsic::vp_mul:
+        case Intrinsic::vp_and:
+        case Intrinsic::vp_or:
+        case Intrinsic::vp_xor:
+        case Intrinsic::vp_fadd:
+        case Intrinsic::vp_fsub:
+        case Intrinsic::vp_fmul:
+        case Intrinsic::vp_fdiv:
+        case Intrinsic::vp_shl:
+        case Intrinsic::vp_lshr:
+        case Intrinsic::vp_ashr:
+        case Intrinsic::vp_udiv:
+        case Intrinsic::vp_sdiv:
+        case Intrinsic::vp_urem:
+        case Intrinsic::vp_srem:
+          return Operand == 2;
+        case Intrinsic::vp_fma:
+          return Operand == 3;
+        default:
+          return false;
+        }
+      }
+      return false;
+    default:
+      return false;
+    }
+  };
+
   for (auto OpIdx : enumerate(I->operands())) {
-    if (!IsSinker(I, OpIdx.index()))
+    Instruction *Op = dyn_cast<Instruction>(OpIdx.value().get());
+    if (!Op)
+      continue;
+
+    if (IsSinker(I, OpIdx.index())) {
+      // Make sure we are not already sinking this operand
+      if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
+        continue;
+
+      // We are looking for a splat that can be sunk.
+      if (!match(Op, m_Shuffle(m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()),
+                               m_Undef(), m_ZeroMask())))
+        continue;
+
+      // All uses of the shuffle should be sunk to avoid duplicating it across
+      // gpr and vector registers
+      for (Use &U : Op->uses()) {
+        Instruction *Insn = cast<Instruction>(U.getUser());
+        if (!IsSinker(Insn, U.getOperandNo()))
+          return false;
+      }
+
+      Ops.push_back(&Op->getOperandUse(0));
+      Ops.push_back(&OpIdx.value());
+      continue;
+    }
+
+    // Look for all ones mask operands that need to be sunk to enable isel to
+    // recognize that they are unmasked operations.
+    if (!isa<VectorType>(OpIdx.value().get()->getType()) ||
+        !IsMaskOperand(I, OpIdx.index()))
       continue;
 
-    Instruction *Op = dyn_cast<Instruction>(OpIdx.value().get());
     // Make sure we are not already sinking this operand
-    if (!Op || any_of(Ops, [&](Use *U) { return U->get() == Op; }))
+    if (any_of(Ops, [&](Use *U) { return U->get() == Op; }))
      continue;
 
-    // We are looking for a splat that can be sunk.
-    if (!match(Op, m_Shuffle(m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()),
+    // We are looking for a splat of 1 that can be sunk.
+    assert(Op->getType()->isIntOrIntVectorTy(1) && "Unexpected type");
+    if (!match(Op, m_Shuffle(m_InsertElt(m_Undef(), m_One(), m_ZeroInt()),
                              m_Undef(), m_ZeroMask())))
       continue;
 
-    // All uses of the shuffle should be sunk to avoid duplicating it across gpr
-    // and vector registers
-    for (Use &U : Op->uses()) {
-      Instruction *Insn = cast<Instruction>(U.getUser());
-      if (!IsSinker(Insn, U.getOperandNo()))
-        return false;
-    }
-
     Ops.push_back(&Op->getOperandUse(0));
     Ops.push_back(&OpIdx.value());
+    continue;
   }
   return true;
 }
diff --git a/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll b/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll
--- a/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll
@@ -3265,7 +3265,7 @@
   ret void
 }
 
-declare <4 x float> @llvm.vp.fmul.v4i32(<4 x float>, <4 x float>, <4 x i1>, i32)
+declare <4 x float> @llvm.vp.fmul.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
 
 define void @sink_splat_vp_fmul(float* nocapture %a, float %x, <4 x i1> %m, i32 zeroext %vl) {
 ; CHECK-LABEL: sink_splat_vp_fmul:
@@ -3294,7 +3294,7 @@
   %0 = getelementptr inbounds float, float* %a, i64 %index
   %1 = bitcast float* %0 to <4 x float>*
   %wide.load = load <4 x float>, <4 x float>* %1, align 4
-  %2 = call <4 x float> @llvm.vp.fmul.v4i32(<4 x float> %wide.load, <4 x float> %broadcast.splat, <4 x i1> %m, i32 %vl)
+  %2 = call <4 x float> @llvm.vp.fmul.v4f32(<4 x float> %wide.load, <4 x float> %broadcast.splat, <4 x i1> %m, i32 %vl)
   %3 = bitcast float* %0 to <4 x float>*
   store <4 x float> %2, <4 x float>* %3, align 4
   %index.next = add nuw i64 %index, 4
@@ -3305,7 +3305,7 @@
   ret void
 }
 
-declare <4 x float> @llvm.vp.fdiv.v4i32(<4 x float>, <4 x float>, <4 x i1>, i32)
+declare <4 x float> @llvm.vp.fdiv.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
 
 define void @sink_splat_vp_fdiv(float* nocapture %a, float %x, <4 x i1> %m, i32 zeroext %vl) {
 ; CHECK-LABEL: sink_splat_vp_fdiv:
@@ -3334,7 +3334,7 @@
   %0 = getelementptr inbounds float, float* %a, i64 %index
   %1 = bitcast float* %0 to <4 x float>*
   %wide.load = load <4 x float>, <4 x float>* %1, align 4
-  %2 = call <4 x float> @llvm.vp.fdiv.v4i32(<4 x float> %wide.load, <4 x float> %broadcast.splat, <4 x i1> %m, i32 %vl)
+  %2 = call <4 x float> @llvm.vp.fdiv.v4f32(<4 x float> %wide.load, <4 x float> %broadcast.splat, <4 x i1> %m, i32 %vl)
   %3 = bitcast float* %0 to <4 x float>*
   store <4 x float> %2, <4 x float>* %3, align 4
   %index.next = add nuw i64 %index, 4
@@ -3372,7 +3372,7 @@
   %0 = getelementptr inbounds float, float* %a, i64 %index
   %1 = bitcast float* %0 to <4 x float>*
   %wide.load = load <4 x float>, <4 x float>* %1, align 4
-  %2 = call <4 x float> @llvm.vp.fdiv.v4i32(<4 x float> %broadcast.splat, <4 x float> %wide.load, <4 x i1> %m, i32 %vl)
+  %2 = call <4 x float> @llvm.vp.fdiv.v4f32(<4 x float> %broadcast.splat, <4 x float> %wide.load, <4 x i1> %m, i32 %vl)
   %3 = bitcast float* %0 to <4 x float>*
   store <4 x float> %2, <4 x float>* %3, align 4
   %index.next = add nuw i64 %index, 4
@@ -3383,7 +3383,7 @@
   ret void
 }
 
-declare <4 x float> @llvm.vp.fadd.v4i32(<4 x float>, <4 x float>, <4 x i1>, i32)
+declare <4 x float> @llvm.vp.fadd.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
 
 define void @sink_splat_vp_fadd(float* nocapture %a, float %x, <4 x i1> %m, i32 zeroext %vl) {
 ; CHECK-LABEL: sink_splat_vp_fadd:
@@ -3412,7 +3412,7 @@
   %0 = getelementptr inbounds float, float* %a, i64 %index
   %1 = bitcast float* %0 to <4 x float>*
   %wide.load = load <4 x float>, <4 x float>* %1, align 4
-  %2 = call <4 x float> @llvm.vp.fadd.v4i32(<4 x float> %wide.load, <4 x float> %broadcast.splat, <4 x i1> %m, i32 %vl)
+  %2 = call <4 x float> @llvm.vp.fadd.v4f32(<4 x float> %wide.load, <4 x float> %broadcast.splat, <4 x i1> %m, i32 %vl)
   %3 = bitcast float* %0 to <4 x float>*
   store <4 x float> %2, <4 x float>* %3, align 4
   %index.next = add nuw i64 %index, 4
@@ -3423,7 +3423,7 @@
   ret void
 }
 
-declare <4 x float> @llvm.vp.fsub.v4i32(<4 x float>, <4 x float>, <4 x i1>, i32)
+declare <4 x float> @llvm.vp.fsub.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
 
 define void @sink_splat_vp_fsub(float* nocapture %a, float %x, <4 x i1> %m, i32 zeroext %vl) {
 ; CHECK-LABEL: sink_splat_vp_fsub:
@@ -3452,7 +3452,7 @@
   %0 = getelementptr inbounds float, float* %a, i64 %index
   %1 = bitcast float* %0 to <4 x float>*
   %wide.load = load <4 x float>, <4 x float>* %1, align 4
-  %2 = call <4 x float> @llvm.vp.fsub.v4i32(<4 x float> %wide.load, <4 x float> %broadcast.splat, <4 x i1> %m, i32 %vl)
+  %2 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %wide.load, <4 x float> %broadcast.splat, <4 x i1> %m, i32 %vl)
   %3 = bitcast float* %0 to <4 x float>*
   store <4 x float> %2, <4 x float>* %3, align 4
   %index.next = add nuw i64 %index, 4
@@ -3463,7 +3463,7 @@
   ret void
 }
 
-declare <4 x float> @llvm.vp.frsub.v4i32(<4 x float>, <4 x float>, <4 x i1>, i32)
+declare <4 x float> @llvm.vp.frsub.v4f32(<4 x float>, <4 x float>, <4 x i1>, i32)
 
 define void @sink_splat_vp_frsub(float* nocapture %a, float %x, <4 x i1> %m, i32 zeroext %vl) {
 ; CHECK-LABEL: sink_splat_vp_frsub:
@@ -3492,7 +3492,7 @@
   %0 = getelementptr inbounds float, float* %a, i64 %index
   %1 = bitcast float* %0 to <4 x float>*
   %wide.load = load <4 x float>, <4 x float>* %1, align 4
-  %2 = call <4 x float> @llvm.vp.fsub.v4i32(<4 x float> %broadcast.splat, <4 x float> %wide.load, <4 x i1> %m, i32 %vl)
+  %2 = call <4 x float> @llvm.vp.fsub.v4f32(<4 x float> %broadcast.splat, <4 x float> %wide.load, <4 x i1> %m, i32 %vl)
   %3 = bitcast float* %0 to <4 x float>*
   store <4 x float> %2, <4 x float>* %3, align 4
   %index.next = add nuw i64 %index, 4
@@ -3791,3 +3791,126 @@
 for.cond.cleanup:                                 ; preds = %vector.body
   ret void
 }
+
+declare <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i1>, i32)
+
+define void @sink_all_ones_mask_vp_add(i32* nocapture %a, i32 signext %x, i32 zeroext %vl) {
+; CHECK-LABEL: sink_all_ones_mask_vp_add:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a3, 1024
+; CHECK-NEXT:  .LBB67_1: # %vector.body
+; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    vl2re32.v v8, (a0)
+; CHECK-NEXT:    vsetvli zero, a2, e32, m2, ta, mu
+; CHECK-NEXT:    vadd.vx v8, v8, a1
+; CHECK-NEXT:    vs2r.v v8, (a0)
+; CHECK-NEXT:    addi a3, a3, -4
+; CHECK-NEXT:    addi a0, a0, 16
+; CHECK-NEXT:    bnez a3, .LBB67_1
+; CHECK-NEXT:  # %bb.2: # %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  %broadcast.splatinsert = insertelement <vscale x 4 x i32> poison, i32 %x, i32 0
+  %broadcast.splat = shufflevector <vscale x 4 x i32> %broadcast.splatinsert, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+  %mask.splatinsert = insertelement <vscale x 4 x i1> poison, i1 -1, i32 0
+  %m = shufflevector <vscale x 4 x i1> %mask.splatinsert, <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %entry
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %0 = getelementptr inbounds i32, i32* %a, i64 %index
+  %1 = bitcast i32* %0 to <vscale x 4 x i32>*
+  %wide.load = load <vscale x 4 x i32>, <vscale x 4 x i32>* %1, align 4
+  %2 = call <vscale x 4 x i32> @llvm.vp.add.nxv4i32(<vscale x 4 x i32> %wide.load, <vscale x 4 x i32> %broadcast.splat, <vscale x 4 x i1> %m, i32 %vl)
+  %3 = bitcast i32* %0 to <vscale x 4 x i32>*
+  store <vscale x 4 x i32> %2, <vscale x 4 x i32>* %3, align 4
+  %index.next = add nuw i64 %index, 4
+  %4 = icmp eq i64 %index.next, 1024
+  br i1 %4, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:                                 ; preds = %vector.body
+  ret void
+}
+
+declare <vscale x 4 x float> @llvm.vp.fma.nxv4f32(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x i1>, i32)
+
+define void @sink_all_ones_mask_vp_fma(float* noalias nocapture %a, float* nocapture readonly %b, float %x, i32 zeroext %vl) {
+; CHECK-LABEL: sink_all_ones_mask_vp_fma:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a3, 1024
+; CHECK-NEXT:  .LBB68_1: # %vector.body
+; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    vl2re32.v v8, (a0)
+; CHECK-NEXT:    vl2re32.v v10, (a1)
+; CHECK-NEXT:    vsetvli zero, a2, e32, m2, ta, mu
+; CHECK-NEXT:    vfmacc.vf v10, fa0, v8
+; CHECK-NEXT:    vs2r.v v10, (a0)
+; CHECK-NEXT:    addi a3, a3, -4
+; CHECK-NEXT:    addi a1, a1, 16
+; CHECK-NEXT:    addi a0, a0, 16
+; CHECK-NEXT:    bnez a3, .LBB68_1
+; CHECK-NEXT:  # %bb.2: # %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  %broadcast.splatinsert = insertelement <vscale x 4 x float> poison, float %x, i32 0
+  %broadcast.splat = shufflevector <vscale x 4 x float> %broadcast.splatinsert, <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
+  %mask.splatinsert = insertelement <vscale x 4 x i1> poison, i1 -1, i32 0
+  %m = shufflevector <vscale x 4 x i1> %mask.splatinsert, <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %entry
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %0 = getelementptr inbounds float, float* %a, i64 %index
+  %1 = bitcast float* %0 to <vscale x 4 x float>*
+  %wide.load = load <vscale x 4 x float>, <vscale x 4 x float>* %1, align 4
+  %2 = getelementptr inbounds float, float* %b, i64 %index
+  %3 = bitcast float* %2 to <vscale x 4 x float>*
+  %wide.load12 = load <vscale x 4 x float>, <vscale x 4 x float>* %3, align 4
+  %4 = call <vscale x 4 x float> @llvm.vp.fma.nxv4f32(<vscale x 4 x float> %wide.load, <vscale x 4 x float> %broadcast.splat, <vscale x 4 x float> %wide.load12, <vscale x 4 x i1> %m, i32 %vl)
+  %5 = bitcast float* %0 to <vscale x 4 x float>*
+  store <vscale x 4 x float> %4, <vscale x 4 x float>* %5, align 4
+  %index.next = add nuw i64 %index, 4
+  %6 = icmp eq i64 %index.next, 1024
+  br i1 %6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:                                 ; preds = %vector.body
+  ret void
+}
+
+declare <vscale x 4 x float> @llvm.vp.fneg.nxv4f32(<vscale x 4 x float>, <vscale x 4 x i1>, i32)
+
+define void @sink_all_ones_mask_vp_fneg(float* nocapture %a, i32 zeroext %vl) {
+; CHECK-LABEL: sink_all_ones_mask_vp_fneg:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a2, 1024
+; CHECK-NEXT:  .LBB69_1: # %vector.body
+; CHECK-NEXT:    # =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    vl2re32.v v8, (a0)
+; CHECK-NEXT:    vsetvli zero, a1, e32, m2, ta, mu
+; CHECK-NEXT:    vfsgnjn.vv v8, v8, v8
+; CHECK-NEXT:    vs2r.v v8, (a0)
+; CHECK-NEXT:    addi a2, a2, -4
+; CHECK-NEXT:    addi a0, a0, 16
+; CHECK-NEXT:    bnez a2, .LBB69_1
+; CHECK-NEXT:  # %bb.2: # %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  %mask.splatinsert = insertelement <vscale x 4 x i1> poison, i1 -1, i32 0
+  %m = shufflevector <vscale x 4 x i1> %mask.splatinsert, <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer
+  br label %vector.body
+
+vector.body:                                      ; preds = %vector.body, %entry
+  %index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
+  %0 = getelementptr inbounds float, float* %a, i64 %index
+  %1 = bitcast float* %0 to <vscale x 4 x float>*
+  %wide.load = load <vscale x 4 x float>, <vscale x 4 x float>* %1, align 4
+  %2 = call <vscale x 4 x float> @llvm.vp.fneg.nxv4f32(<vscale x 4 x float> %wide.load, <vscale x 4 x i1> %m, i32 %vl)
+  %3 = bitcast float* %0 to <vscale x 4 x float>*
+  store <vscale x 4 x float> %2, <vscale x 4 x float>* %3, align 4
+  %index.next = add nuw i64 %index, 4
+  %4 = icmp eq i64 %index.next, 1024
+  br i1 %4, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:                                 ; preds = %vector.body
+  ret void
+}