diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
@@ -42,6 +42,62 @@
 } // namespace RISCV
 } // namespace llvm
 
+// Returns the VL operand if the input SDNode has one.
+static SDValue getVLOperand(const SDNode *Op) {
+  // Try to get the operand that might be an IntrinsicID from llvm::Intrinsic.
+  constexpr unsigned IIDOpNo = 1;
+  if (Op->getNumOperands() < IIDOpNo + 1)
+    return SDValue();
+  auto *IIDOp = dyn_cast<ConstantSDNode>(Op->getOperand(IIDOpNo).getNode());
+  if (!IIDOp)
+    return SDValue();
+
+  // Get RISCVVIntrinsicInfo to find the location of the VL operand, if any.
+  const RISCVVIntrinsicsTable::RISCVVIntrinsicInfo *II =
+      RISCVVIntrinsicsTable::getRISCVVIntrinsicInfo(IIDOp->getZExtValue());
+  if (!II || !II->hasVLOperand())
+    return SDValue();
+  return Op->getOperand(II->VLOperand + 1 /* IID */ + 1 /* chain */);
+}
+
+// Returns an SDValue holding the maximum zero-extended value of the VL
+// operands of the nodes in the range. If any node has no VL operand, or its
+// VL operand is not a constant, an empty SDValue is returned.
+static SDValue findMaxVLConstant(iterator_range<SDNode::use_iterator> Range) {
+  assert(!Range.empty() && "Invalid range");
+  std::pair<SDValue, uint64_t> MaxVL{SDValue{}, 0};
+  for (SDNode *Node : Range) {
+    SDValue VL = getVLOperand(Node);
+    if (!VL)
+      return SDValue();
+    auto *ConstantNode = dyn_cast<ConstantSDNode>(VL.getNode());
+    if (!ConstantNode)
+      return SDValue();
+    uint64_t VLValue = ConstantNode->getZExtValue();
+    if (MaxVL.second < VLValue)
+      MaxVL = std::make_pair(VL, VLValue);
+  }
+  return MaxVL.first;
+}
+
+// Returns the common VL of the users in the input range, if any.
+static SDValue getCommonVL(iterator_range<SDNode::use_iterator> &&Range) {
+  if (Range.empty())
+    return SDValue();
+
+  // If all VL operands are known constants, find the max VL and return it.
+  if (SDValue ConstantVL = findMaxVLConstant(Range))
+    return ConstantVL;
+
+  // Check whether the VL operands are the same. Return the common
+  // non-constant VL.
+  SDValue VL = getVLOperand(*Range.begin());
+  if (all_of(drop_begin(Range), [VL](SDNode *U) {
+        return getVLOperand(U).getNode() == VL.getNode();
+      }))
+    return VL;
+  return SDValue();
+}
+
 void RISCVDAGToDAGISel::PreprocessISelDAG() {
   SelectionDAG::allnodes_iterator Position = CurDAG->allnodes_end();
 
@@ -60,7 +116,11 @@
       unsigned Opc =
          VT.isInteger() ? RISCVISD::VMV_V_X_VL : RISCVISD::VFMV_V_F_VL;
       SDLoc DL(N);
-      SDValue VL = CurDAG->getRegister(RISCV::X0, Subtarget->getXLenVT());
+      // If all users of the splat have the same VL, we can use it for VMV.
+      // Use X0 otherwise.
+      SDValue VL = getCommonVL(N->uses());
+      if (!VL)
+        VL = CurDAG->getRegister(RISCV::X0, Subtarget->getXLenVT());
       Result = CurDAG->getNode(Opc, DL, VT, CurDAG->getUNDEF(VT),
                                N->getOperand(0), VL);
       break;
diff --git a/llvm/test/CodeGen/RISCV/rvv/pr55615.ll b/llvm/test/CodeGen/RISCV/rvv/pr55615.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/rvv/pr55615.ll
@@ -0,0 +1,72 @@
+; RUN: llc -mtriple=riscv64 -mattr=+v < %s | FileCheck %s
+
+define void @vector_splat_toggle_const_eq(double* %a, double* %b) {
+; CHECK-LABEL: vector_splat_toggle_const_eq
+; CHECK: vsetivli zero, 4, e64, m1, ta, mu
+; CHECK-NEXT: vmv.v.i v8, 0
+; CHECK-NEXT: vse64.v v8, (a0)
+; CHECK-NEXT: vse64.v v8, (a1)
+; CHECK-NEXT: ret
+  %addr = bitcast double* %a to <vscale x 1 x double>*
+  tail call void @llvm.riscv.vse.nxv1f64.i64(<vscale x 1 x double> zeroinitializer, <vscale x 1 x double>* %addr, i64 4)
+  %addr2 = bitcast double* %b to <vscale x 1 x double>*
+  tail call void @llvm.riscv.vse.nxv1f64.i64(<vscale x 1 x double> zeroinitializer, <vscale x 1 x double>* %addr2, i64 4)
+  ret void
+}
+
+define void @vector_splat_toggle_const_ne(double* %a, double* %b) {
+; CHECK-LABEL: vector_splat_toggle_const_ne
+; CHECK: vsetivli zero, 4, e64, m1, ta, mu
+; CHECK-NEXT: vmv.v.i v8, 0
+; CHECK-NEXT: vsetivli zero, 2, e64, m1, ta, mu
+; CHECK-NEXT: vse64.v v8, (a0)
+; CHECK-NEXT: vsetivli zero, 4, e64, m1, ta, mu
+; CHECK-NEXT: vse64.v v8, (a1)
+; CHECK-NEXT: ret
+  %addr = bitcast double* %a to <vscale x 1 x double>*
+  tail call void @llvm.riscv.vse.nxv1f64.i64(<vscale x 1 x double> zeroinitializer, <vscale x 1 x double>* %addr, i64 2)
+  %addr2 = bitcast double* %b to <vscale x 1 x double>*
+  tail call void @llvm.riscv.vse.nxv1f64.i64(<vscale x 1 x double> zeroinitializer, <vscale x 1 x double>* %addr2, i64 4)
+  ret void
+}
+
+define void @vector_splat_toggle_nonconst_eq(double* %a, double* %b, i64 %n) {
+; CHECK-LABEL: vector_splat_toggle_nonconst_eq
+; CHECK: vsetvli zero, a2, e64, m1, ta, mu
+; CHECK-NEXT: vmv.v.i v8, 0
+; CHECK-NEXT: vse64.v v8, (a0)
+; CHECK-NEXT: vse64.v v8, (a1)
+; CHECK-NEXT: ret
+  %vl = tail call i64 @llvm.riscv.vsetvli.i64(i64 %n, i64 3, i64 0)
+  %addr = bitcast double* %a to <vscale x 1 x double>*
+  tail call void @llvm.riscv.vse.nxv1f64.i64(<vscale x 1 x double> zeroinitializer, <vscale x 1 x double>* %addr, i64 %vl)
+  %addr2 = bitcast double* %b to <vscale x 1 x double>*
+  tail call void @llvm.riscv.vse.nxv1f64.i64(<vscale x 1 x double> zeroinitializer, <vscale x 1 x double>* %addr2, i64 %vl)
+  ret void
+}
+
+; Negative test
+define void @vector_splat_toggle_nonconst_ne(double* %a, double* %b, i64 %n1, i64 %n2) {
+; CHECK-LABEL: vector_splat_toggle_nonconst_ne
+; CHECK: vsetvli a2, a2, e64, m1, ta, mu
+; CHECK-NEXT: vsetvli a3, a3, e64, m1, ta, mu
+; CHECK-NEXT: vsetvli a4, zero, e64, m1, ta, mu
+; CHECK-NEXT: vmv.v.i v8, 0
+; CHECK-NEXT: vsetvli zero, a2, e64, m1, ta, mu
+; CHECK-NEXT: vse64.v v8, (a0)
+; CHECK-NEXT: vsetvli zero, a3, e64, m1, ta, mu
+; CHECK-NEXT: vse64.v v8, (a1)
+; CHECK-NEXT: ret
+  %vl1 = tail call i64 @llvm.riscv.vsetvli.i64(i64 %n1, i64 3, i64 0)
+  %vl2 = tail call i64 @llvm.riscv.vsetvli.i64(i64 %n2, i64 3, i64 0)
+  %addr = bitcast double* %a to <vscale x 1 x double>*
+  tail call void @llvm.riscv.vse.nxv1f64.i64(<vscale x 1 x double> zeroinitializer, <vscale x 1 x double>* %addr, i64 %vl1)
+  %addr2 = bitcast double* %b to <vscale x 1 x double>*
+  tail call void @llvm.riscv.vse.nxv1f64.i64(<vscale x 1 x double> zeroinitializer, <vscale x 1 x double>* %addr2, i64 %vl2)
+  ret void
+}
+
+; Function Attrs: nounwind writeonly
+declare void @llvm.riscv.vse.nxv1f64.i64(<vscale x 1 x double>, <vscale x 1 x double>* nocapture, i64)
+
+declare i64 @llvm.riscv.vsetvli.i64(i64, i64, i64)
diff --git a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll
@@ -677,7 +677,6 @@
 ; CHECK-NEXT: # %bb.1: # %for.body.preheader
 ; CHECK-NEXT: li a3, 0
 ; CHECK-NEXT: slli a4, a2, 3
-; CHECK-NEXT: vsetvli a5, zero, e64, m1, ta, mu
 ; CHECK-NEXT: vmv.v.i v8, 0
 ; CHECK-NEXT: .LBB13_2: # %for.body
 ; CHECK-NEXT: # =>This Inner Loop Header: Depth=1
@@ -713,11 +712,9 @@
 ; CHECK-NEXT: li a2, 0
 ; CHECK-NEXT: vsetivli a3, 4, e64, m1, ta, mu
 ; CHECK-NEXT: slli a4, a3, 3
-; CHECK-NEXT: vsetvli a5, zero, e64, m1, ta, mu
 ; CHECK-NEXT: vmv.v.i v8, 0
 ; CHECK-NEXT: .LBB14_1: # %for.body
 ; CHECK-NEXT: # =>This Inner Loop Header: Depth=1
-; CHECK-NEXT: vsetivli zero, 4, e64, m1, ta, mu
 ; CHECK-NEXT: vse64.v v8, (a1)
 ; CHECK-NEXT: add a2, a2, a3
 ; CHECK-NEXT: add a1, a1, a4
@@ -747,11 +744,10 @@
 ; CHECK-LABEL: vector_init_vsetvli_fv2:
 ; CHECK: # %bb.0: # %entry
 ; CHECK-NEXT: li a2, 0
-; CHECK-NEXT: vsetvli a3, zero, e64, m1, ta, mu
+; CHECK-NEXT: vsetivli zero, 4, e64, m1, ta, mu
 ; CHECK-NEXT: vmv.v.i v8, 0
 ; CHECK-NEXT: .LBB15_1: # %for.body
 ; CHECK-NEXT: # =>This Inner Loop Header: Depth=1
-; CHECK-NEXT: vsetivli zero, 4, e64, m1, ta, mu
 ; CHECK-NEXT: vse64.v v8, (a1)
 ; CHECK-NEXT: addi a2, a2, 4
 ; CHECK-NEXT: addi a1, a1, 32
@@ -781,11 +777,10 @@
 ; CHECK-LABEL: vector_init_vsetvli_fv3:
 ; CHECK: # %bb.0: # %entry
 ; CHECK-NEXT: li a2, 0
-; CHECK-NEXT: vsetvli a3, zero, e64, m1, ta, mu
+; CHECK-NEXT: vsetivli zero, 4, e64, m1, ta, mu
 ; CHECK-NEXT: vmv.v.i v8, 0
 ; CHECK-NEXT: .LBB16_1: # %for.body
 ; CHECK-NEXT: # =>This Inner Loop Header: Depth=1
-; CHECK-NEXT: vsetivli zero, 4, e64, m1, ta, mu
 ; CHECK-NEXT: vse64.v v8, (a1)
 ; CHECK-NEXT: addi a2, a2, 4
 ; CHECK-NEXT: addi a1, a1, 32