diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp --- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp +++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp @@ -58,10 +58,12 @@ uint8_t TailAgnostic : 1; uint8_t MaskAgnostic : 1; uint8_t MaskRegOp : 1; + uint8_t SEWLMULRatioOnly : 1; public: VSETVLIInfo() - : AVLImm(0), TailAgnostic(false), MaskAgnostic(false), MaskRegOp(false) {} + : AVLImm(0), TailAgnostic(false), MaskAgnostic(false), MaskRegOp(false), + SEWLMULRatioOnly(false) {} static VSETVLIInfo getUnknown() { VSETVLIInfo Info; @@ -127,16 +129,20 @@ } unsigned encodeVTYPE() const { - assert(isValid() && !isUnknown() && + assert(isValid() && !isUnknown() && !SEWLMULRatioOnly && "Can't encode VTYPE for uninitialized or unknown"); return RISCVVType::encodeVTYPE(VLMul, SEW, TailAgnostic, MaskAgnostic); } + bool hasSEWLMULRatioOnly() const { return SEWLMULRatioOnly; } + bool hasSameVTYPE(const VSETVLIInfo &Other) const { assert(isValid() && Other.isValid() && "Can't compare invalid VSETVLIInfos"); assert(!isUnknown() && !Other.isUnknown() && "Can't compare VTYPE in unknown state"); + assert(!SEWLMULRatioOnly && !Other.SEWLMULRatioOnly && + "Can't compare when only LMUL/SEW ratio is valid."); return std::tie(VLMul, SEW, TailAgnostic, MaskAgnostic) == std::tie(Other.VLMul, Other.SEW, Other.TailAgnostic, Other.MaskAgnostic); @@ -172,10 +178,16 @@ bool isCompatible(const VSETVLIInfo &InstrInfo) const { assert(isValid() && InstrInfo.isValid() && "Can't compare invalid VSETVLIInfos"); + assert(!InstrInfo.SEWLMULRatioOnly && + "Expected a valid VTYPE for instruction!"); // Nothing is compatible with Unknown. if (isUnknown() || InstrInfo.isUnknown()) return false; + // If only our SEW/LMUL ratio (not the full VTYPE) is valid, this isn't compatible. + if (SEWLMULRatioOnly) + return false; + // If the instruction doesn't need an AVLReg and the SEW matches, consider // it/ compatible. 
if (InstrInfo.hasAVLReg() && InstrInfo.AVLReg == RISCV::NoRegister) { @@ -209,8 +221,19 @@ if (Other.isUnknown()) return isUnknown(); - // Otherwise compare the VTYPE and AVL. - return hasSameVTYPE(Other) && hasSameAVL(Other); + if (!hasSameAVL(Other)) + return false; + + // If only the SEW/LMUL ratio is valid, check that the implied VLMAX is the same. + if (SEWLMULRatioOnly && Other.SEWLMULRatioOnly) + return hasSameVLMAX(Other); + + // If the full VTYPE is valid, check that it is the same. + if (!SEWLMULRatioOnly && !Other.SEWLMULRatioOnly) + return hasSameVTYPE(Other); + + // If the SEWLMULRatioOnly bits are different, then they aren't equal. + return false; } // Calculate the VSETVLIInfo visible to a block assuming this and Other are @@ -224,10 +247,23 @@ if (!isValid()) return Other; + // If either is unknown, the result is unknown. + if (isUnknown() || Other.isUnknown()) + return VSETVLIInfo::getUnknown(); + + // If we have an exact match, return this. if (*this == Other) return *this; - // If the configurations don't match, assume unknown. + // Not an exact match, but maybe the AVL and VLMAX are the same. If so, + // return a value where only the SEW/LMUL ratio is valid. + if (hasSameAVL(Other) && hasSameVLMAX(Other)) { + VSETVLIInfo MergeInfo = *this; + MergeInfo.SEWLMULRatioOnly = true; + return MergeInfo; + } + + // Otherwise the result is unknown. return VSETVLIInfo::getUnknown(); } @@ -444,7 +480,8 @@ // and the last VL/VTYPE we observed is the same, we don't need a // VSETVLI here. 
if (!CurInfo.isUnknown() && Require.hasAVLReg() && - Require.getAVLReg().isVirtual() && Require.hasSameVTYPE(CurInfo)) { + Require.getAVLReg().isVirtual() && !CurInfo.hasSEWLMULRatioOnly() && + Require.hasSameVTYPE(CurInfo)) { if (MachineInstr *DefMI = MRI->getVRegDef(Require.getAVLReg())) { if (DefMI->getOpcode() == RISCV::PseudoVSETVLI || DefMI->getOpcode() == RISCV::PseudoVSETIVLI) { diff --git a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-unaligned.ll b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-unaligned.ll --- a/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-unaligned.ll +++ b/llvm/test/CodeGen/RISCV/rvv/fixed-vectors-unaligned.ll @@ -584,7 +584,7 @@ ; RV32-NEXT: vsetvli zero, zero, e32, mf2, tu, mu ; RV32-NEXT: vslideup.vi v25, v26, 1 ; RV32-NEXT: .LBB8_4: # %else2 -; RV32-NEXT: vsetivli zero, 2, e32, mf2, ta, mu +; RV32-NEXT: vsetvli zero, zero, e32, mf2, ta, mu ; RV32-NEXT: vse32.v v25, (a1) ; RV32-NEXT: addi sp, sp, 16 ; RV32-NEXT: ret @@ -644,7 +644,7 @@ ; RV64-NEXT: vsetvli zero, zero, e32, mf2, tu, mu ; RV64-NEXT: vslideup.vi v25, v26, 1 ; RV64-NEXT: .LBB8_4: # %else2 -; RV64-NEXT: vsetivli zero, 2, e32, mf2, ta, mu +; RV64-NEXT: vsetvli zero, zero, e32, mf2, ta, mu ; RV64-NEXT: vse32.v v25, (a1) ; RV64-NEXT: addi sp, sp, 16 ; RV64-NEXT: ret diff --git a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll --- a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll +++ b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert-crossbb.ll @@ -496,3 +496,92 @@ declare @llvm.riscv.vle.nxv16f32.i64(* nocapture, i64) declare @llvm.riscv.vfmacc.nxv16f32.f32.i64(, float, , i64) declare void @llvm.riscv.vse.nxv16f32.i64(, * nocapture, i64) + +; We need a vsetvli in the last block because the predecessors have different +; VTYPEs. The AVL is the same and the SEW/LMUL ratio implies the same VLMAX, so +; we don't need to read AVL and can keep VL unchanged. 
+define @test_vsetvli_x0_x0(* %x, * %y, %z, i64 %vl, i1 %cond) nounwind { +; CHECK-LABEL: test_vsetvli_x0_x0: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli zero, a2, e32, m1, ta, mu +; CHECK-NEXT: vle32.v v25, (a0) +; CHECK-NEXT: andi a0, a3, 1 +; CHECK-NEXT: beqz a0, .LBB9_2 +; CHECK-NEXT: # %bb.1: # %if +; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, mu +; CHECK-NEXT: vle16.v v26, (a1) +; CHECK-NEXT: vwadd.vx v8, v26, zero +; CHECK-NEXT: .LBB9_2: # %if.end +; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, mu +; CHECK-NEXT: vadd.vv v8, v25, v8 +; CHECK-NEXT: ret +entry: + %a = call @llvm.riscv.vle.nxv2i32(* %x, i64 %vl) + br i1 %cond, label %if, label %if.end + +if: + %b = call @llvm.riscv.vle.nxv2i16(* %y, i64 %vl) + %c = call @llvm.riscv.vwadd.nxv2i32( %b, i16 0, i64 %vl) + br label %if.end + +if.end: + %d = phi [ %z, %entry ], [ %c, %if ] + %e = call @llvm.riscv.vadd.nxv2i32( %a, %d, i64 %vl) + ret %e +} +declare @llvm.riscv.vle.nxv2i32(*, i64) +declare @llvm.riscv.vle.nxv2i16(*, i64) +declare @llvm.riscv.vwadd.nxv2i32(, i16, i64) +declare @llvm.riscv.vadd.nxv2i32(, , i64) + +; We can use X0, X0 vsetvli in if2 and if2.end. The merge point at if.end will +; see two different vtypes with the same SEW/LMUL ratio. At if2.end we will only +; know the SEW/LMUL ratio for the if.end predecessor and the full vtype for +; the if2 predecessor. This makes sure we can merge a SEW/LMUL-ratio-only predecessor with +; a predecessor we know the vtype for. 
+define @test_vsetvli_x0_x0_2(* %x, * %y, * %z, i64 %vl, i1 %cond, i1 %cond2, %w) nounwind { +; CHECK-LABEL: test_vsetvli_x0_x0_2: +; CHECK: # %bb.0: # %entry +; CHECK-NEXT: vsetvli zero, a3, e32, m1, ta, mu +; CHECK-NEXT: vle32.v v25, (a0) +; CHECK-NEXT: andi a0, a4, 1 +; CHECK-NEXT: beqz a0, .LBB10_2 +; CHECK-NEXT: # %bb.1: # %if +; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, mu +; CHECK-NEXT: vle16.v v26, (a1) +; CHECK-NEXT: vwadd.wv v25, v25, v26 +; CHECK-NEXT: .LBB10_2: # %if.end +; CHECK-NEXT: andi a0, a5, 1 +; CHECK-NEXT: beqz a0, .LBB10_4 +; CHECK-NEXT: # %bb.3: # %if2 +; CHECK-NEXT: vsetvli zero, zero, e16, mf2, ta, mu +; CHECK-NEXT: vle16.v v26, (a2) +; CHECK-NEXT: vwadd.wv v25, v25, v26 +; CHECK-NEXT: .LBB10_4: # %if2.end +; CHECK-NEXT: vsetvli zero, zero, e32, m1, ta, mu +; CHECK-NEXT: vadd.vv v8, v25, v8 +; CHECK-NEXT: ret +entry: + %a = call @llvm.riscv.vle.nxv2i32(* %x, i64 %vl) + br i1 %cond, label %if, label %if.end + +if: + %b = call @llvm.riscv.vle.nxv2i16(* %y, i64 %vl) + %c = call @llvm.riscv.vwadd.w.nxv2i32.nxv2i16( %a, %b, i64 %vl) + br label %if.end + +if.end: + %d = phi [ %a, %entry ], [ %c, %if ] + br i1 %cond2, label %if2, label %if2.end + +if2: + %e = call @llvm.riscv.vle.nxv2i16(* %z, i64 %vl) + %f = call @llvm.riscv.vwadd.w.nxv2i32.nxv2i16( %d, %e, i64 %vl) + br label %if2.end + +if2.end: + %g = phi [ %d, %if.end ], [ %f, %if2 ] + %h = call @llvm.riscv.vadd.nxv2i32( %g, %w, i64 %vl) + ret %h +} +declare @llvm.riscv.vwadd.w.nxv2i32.nxv2i16(, , i64)