diff --git a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
--- a/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
+++ b/llvm/lib/Target/RISCV/RISCVInsertVSETVLI.cpp
@@ -59,6 +59,7 @@
   uint8_t MaskAgnostic : 1;
   uint8_t MaskRegOp : 1;
   uint8_t StoreOp : 1;
+  uint8_t ScalarMovOp : 1;
   uint8_t SEWLMULRatioOnly : 1;
 
 public:
@@ -96,6 +97,21 @@
     assert(hasAVLImm());
     return AVLImm;
   }
+  bool hasZeroAVL() const {
+    if (hasAVLImm()) {
+      return getAVLImm() == 0;
+    }
+    return false;
+  }
+  bool hasPositiveAVL() const {
+    if (hasAVLImm()) {
+      return getAVLImm() > 0;
+    }
+    if (hasAVLReg()) {
+      return getAVLReg() == RISCV::X0;
+    }
+    return false;
+  }
 
   bool hasSameAVL(const VSETVLIInfo &Other) const {
     assert(isValid() && Other.isValid() &&
@@ -120,7 +136,7 @@
     MaskAgnostic = RISCVVType::isMaskAgnostic(VType);
   }
   void setVTYPE(RISCVII::VLMUL L, unsigned S, bool TA, bool MA, bool MRO,
-                bool IsStore) {
+                bool IsStore, bool IsScalarMovOp) {
     assert(isValid() && !isUnknown() &&
            "Can't set VTYPE for uninitialized or unknown");
     VLMul = L;
@@ -129,6 +145,7 @@
     MaskAgnostic = MA;
     MaskRegOp = MRO;
     StoreOp = IsStore;
+    ScalarMovOp = IsScalarMovOp;
   }
 
   unsigned encodeVTYPE() const {
@@ -139,6 +156,14 @@
 
   bool hasSEWLMULRatioOnly() const { return SEWLMULRatioOnly; }
 
+  bool hasSameSEW(const VSETVLIInfo &Other) const {
+    assert(isValid() && Other.isValid() &&
+           "Can't compare invalid VSETVLIInfos");
+    assert(!isUnknown() && !Other.isUnknown() &&
+           "Can't compare VTYPE in unknown state");
+    return SEW == Other.SEW;
+  }
+
   bool hasSameVTYPE(const VSETVLIInfo &Other) const {
     assert(isValid() && Other.isValid() &&
@@ -178,6 +203,15 @@
     return getSEWLMULRatio() == Other.getSEWLMULRatio();
   }
 
+  bool hasSamePolicy(const VSETVLIInfo &Other) const {
+    assert(isValid() && Other.isValid() &&
+           "Can't compare invalid VSETVLIInfos");
+    assert(!isUnknown() && !Other.isUnknown() &&
+           "Can't compare VTYPE in unknown state");
+    return TailAgnostic == Other.TailAgnostic &&
+           MaskAgnostic == Other.MaskAgnostic;
+  }
+
   bool hasCompatibleVTYPE(const VSETVLIInfo &InstrInfo, bool Strict) const {
     // Simple case, see if full VTYPE matches.
     if (hasSameVTYPE(InstrInfo))
       return true;
@@ -222,6 +256,15 @@
       return true;
     }
 
+    // For vmv.s.x and vfmv.s.f, there are only two behaviors: VL = 0 and
+    // VL > 0. So the states are compatible as long as we can be sure that
+    // both VLs fall into the same case.
+    if (InstrInfo.ScalarMovOp && InstrInfo.hasAVLImm() &&
+        ((hasPositiveAVL() && InstrInfo.hasPositiveAVL()) ||
+         (hasZeroAVL() && InstrInfo.hasZeroAVL())) &&
+        hasSameVTYPE(InstrInfo))
+      return true;
+
     // The AVL must match.
     if (!hasSameAVL(InstrInfo))
       return false;
@@ -414,6 +457,42 @@
   }
 }
 
+bool IsScalarMoveInstr(const MachineInstr &MI) {
+  switch (MI.getOpcode()) {
+  default:
+    return false;
+  case RISCV::PseudoVMV_S_X_M1:
+  case RISCV::PseudoVMV_S_X_M2:
+  case RISCV::PseudoVMV_S_X_M4:
+  case RISCV::PseudoVMV_S_X_M8:
+  case RISCV::PseudoVMV_S_X_MF2:
+  case RISCV::PseudoVMV_S_X_MF4:
+  case RISCV::PseudoVMV_S_X_MF8:
+  case RISCV::PseudoVFMV_F16_S_M1:
+  case RISCV::PseudoVFMV_F16_S_M2:
+  case RISCV::PseudoVFMV_F16_S_M4:
+  case RISCV::PseudoVFMV_F16_S_M8:
+  case RISCV::PseudoVFMV_F16_S_MF2:
+  case RISCV::PseudoVFMV_F16_S_MF4:
+  case RISCV::PseudoVFMV_F16_S_MF8:
+  case RISCV::PseudoVFMV_F32_S_M1:
+  case RISCV::PseudoVFMV_F32_S_M2:
+  case RISCV::PseudoVFMV_F32_S_M4:
+  case RISCV::PseudoVFMV_F32_S_M8:
+  case RISCV::PseudoVFMV_F32_S_MF2:
+  case RISCV::PseudoVFMV_F32_S_MF4:
+  case RISCV::PseudoVFMV_F32_S_MF8:
+  case RISCV::PseudoVFMV_F64_S_M1:
+  case RISCV::PseudoVFMV_F64_S_M2:
+  case RISCV::PseudoVFMV_F64_S_M4:
+  case RISCV::PseudoVFMV_F64_S_M8:
+  case RISCV::PseudoVFMV_F64_S_MF2:
+  case RISCV::PseudoVFMV_F64_S_MF4:
+  case RISCV::PseudoVFMV_F64_S_MF8:
+    return true;
+  }
+}
+
 static VSETVLIInfo computeInfoForInstr(const MachineInstr &MI, uint64_t TSFlags,
                                        const MachineRegisterInfo *MRI) {
   VSETVLIInfo InstrInfo;
@@ -461,6 +540,7 @@
   // If there are no explicit defs, this is a store instruction which can
   // ignore the tail and mask policies.
   bool StoreOp = MI.getNumExplicitDefs() == 0;
+  bool ScalarMovOp = IsScalarMoveInstr(MI);
 
   if (RISCVII::hasVLOp(TSFlags)) {
     const MachineOperand &VLOp = MI.getOperand(NumOperands - 2);
@@ -477,7 +557,7 @@
   } else
     InstrInfo.setAVLReg(RISCV::NoRegister);
   InstrInfo.setVTYPE(VLMul, SEW, /*TailAgnostic*/ TailAgnostic,
-                     /*MaskAgnostic*/ false, MaskRegOp, StoreOp);
+                     /*MaskAgnostic*/ false, MaskRegOp, StoreOp, ScalarMovOp);
 
   return InstrInfo;
 }
@@ -1000,6 +1080,13 @@
           PrevVSETVLIMI->getOperand(2).setImm(NewInfo.encodeVTYPE());
           NeedInsertVSETVLI = false;
         }
+        if (IsScalarMoveInstr(MI) &&
+            ((CurInfo.hasPositiveAVL() && NewInfo.hasPositiveAVL()) ||
+             (CurInfo.hasZeroAVL() && NewInfo.hasZeroAVL())) &&
+            NewInfo.hasSameVLMAX(CurInfo)) {
+          PrevVSETVLIMI->getOperand(2).setImm(NewInfo.encodeVTYPE());
+          NeedInsertVSETVLI = false;
+        }
       }
       if (NeedInsertVSETVLI)
         insertVSETVLI(MBB, MI, NewInfo, CurInfo);
diff --git a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll
--- a/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/vsetvli-insert.ll
@@ -147,8 +147,7 @@
 define <vscale x 1 x i64> @test7(<vscale x 1 x i64> %a, i64 %b, <vscale x 1 x i1> %mask) nounwind {
 ; CHECK-LABEL: test7:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetvli a1, zero, e64, m1, ta, mu
-; CHECK-NEXT:    vsetivli zero, 1, e64, m1, tu, mu
+; CHECK-NEXT:    vsetvli a1, zero, e64, m1, tu, mu
 ; CHECK-NEXT:    vmv.s.x v8, a0
 ; CHECK-NEXT:    ret
 entry:
@@ -163,8 +162,7 @@
 define <vscale x 1 x i64> @test8(<vscale x 1 x i64> %a, i64 %b, <vscale x 1 x i1> %mask) nounwind {
 ; CHECK-LABEL: test8:
 ; CHECK:       # %bb.0: # %entry
-; CHECK-NEXT:    vsetivli a1, 6, e64, m1, ta, mu
-; CHECK-NEXT:    vsetivli zero, 2, e64, m1, tu, mu
+; CHECK-NEXT:    vsetivli a1, 6, e64, m1, tu, mu
 ; CHECK-NEXT:    vmv.s.x v8, a0
 ; CHECK-NEXT:    ret
 entry:
@@ -178,7 +176,6 @@
 ; CHECK:       # %bb.0: # %entry
 ; CHECK-NEXT:    vsetivli zero, 9, e64, m1, tu, mu
 ; CHECK-NEXT:    vadd.vv v8, v8, v8, v0.t
-; CHECK-NEXT:    vsetivli zero, 2, e64, m1, tu, mu
 ; CHECK-NEXT:    vmv.s.x v8, a0
 ; CHECK-NEXT:    ret
 entry: