diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h @@ -95,6 +95,13 @@ // compiler has free to select either one. UsesMaskPolicyShift = IsRVVWideningReductionShift + 1, UsesMaskPolicyMask = 1 << UsesMaskPolicyShift, + + // Indicates that the result can be considered sign extended from bit 31. Some + // instructions with this flag aren't W instructions, but are either sign + // extended from a smaller size, always output a small integer, or put zeros + // in bits 63:31. Used by the SExtWRemoval pass. + IsSignExtendingOpWShift = UsesMaskPolicyShift + 1, + IsSignExtendingOpWMask = 1ULL << IsSignExtendingOpWShift, }; // Match with the definitions in RISCVInstrFormats.td diff --git a/llvm/lib/Target/RISCV/RISCVInstrFormats.td b/llvm/lib/Target/RISCV/RISCVInstrFormats.td --- a/llvm/lib/Target/RISCV/RISCVInstrFormats.td +++ b/llvm/lib/Target/RISCV/RISCVInstrFormats.td @@ -204,6 +204,13 @@ bit UsesMaskPolicy = 0; let TSFlags{18} = UsesMaskPolicy; + + // Indicates that the result can be considered sign extended from bit 31. Some + // instructions with this flag aren't W instructions, but are either sign + // extended from a smaller size, always output a small integer, or put zeros + // in bits 63:31. Used by the SExtWRemoval pass. 
+ bit IsSignExtendingOpW = 0; + let TSFlags{19} = IsSignExtendingOpW; } // Pseudo instructions diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfo.td b/llvm/lib/Target/RISCV/RISCVInstrInfo.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfo.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfo.td @@ -599,7 +599,8 @@ //===----------------------------------------------------------------------===// let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in { -let isReMaterializable = 1, isAsCheapAsAMove = 1 in +let isReMaterializable = 1, isAsCheapAsAMove = 1, + IsSignExtendingOpW = 1 in def LUI : RVInstU, Sched<[WriteIALU]>; @@ -624,11 +625,13 @@ def BLTU : BranchCC_rri<0b110, "bltu">; def BGEU : BranchCC_rri<0b111, "bgeu">; +let IsSignExtendingOpW = 1 in { def LB : Load_ri<0b000, "lb">, Sched<[WriteLDB, ReadMemBase]>; def LH : Load_ri<0b001, "lh">, Sched<[WriteLDH, ReadMemBase]>; def LW : Load_ri<0b010, "lw">, Sched<[WriteLDW, ReadMemBase]>; def LBU : Load_ri<0b100, "lbu">, Sched<[WriteLDB, ReadMemBase]>; def LHU : Load_ri<0b101, "lhu">, Sched<[WriteLDH, ReadMemBase]>; +} def SB : Store_rri<0b000, "sb">, Sched<[WriteSTB, ReadStoreData, ReadMemBase]>; def SH : Store_rri<0b001, "sh">, Sched<[WriteSTH, ReadStoreData, ReadMemBase]>; @@ -639,8 +642,10 @@ let isReMaterializable = 1, isAsCheapAsAMove = 1 in def ADDI : ALU_ri<0b000, "addi">; +let IsSignExtendingOpW = 1 in { def SLTI : ALU_ri<0b010, "slti">; def SLTIU : ALU_ri<0b011, "sltiu">; +} let isReMaterializable = 1, isAsCheapAsAMove = 1 in { def XORI : ALU_ri<0b100, "xori">; @@ -659,10 +664,12 @@ Sched<[WriteIALU, ReadIALU, ReadIALU]>; def SLL : ALU_rr<0b0000000, 0b001, "sll">, Sched<[WriteShiftReg, ReadShiftReg, ReadShiftReg]>; +let IsSignExtendingOpW = 1 in { def SLT : ALU_rr<0b0000000, 0b010, "slt">, Sched<[WriteIALU, ReadIALU, ReadIALU]>; def SLTU : ALU_rr<0b0000000, 0b011, "sltu">, Sched<[WriteIALU, ReadIALU, ReadIALU]>; +} def XOR : ALU_rr<0b0000000, 0b100, "xor", /*Commutable*/1>, Sched<[WriteIALU, ReadIALU, ReadIALU]>; def SRL : 
ALU_rr<0b0000000, 0b101, "srl">, @@ -754,6 +761,7 @@ def LD : Load_ri<0b011, "ld">, Sched<[WriteLDD, ReadMemBase]>; def SD : Store_rri<0b011, "sd">, Sched<[WriteSTD, ReadStoreData, ReadMemBase]>; +let IsSignExtendingOpW = 1 in { let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in def ADDIW : RVInstI<0b000, OPC_OP_IMM_32, (outs GPR:$rd), (ins GPR:$rs1, simm12:$imm12), @@ -774,6 +782,7 @@ Sched<[WriteShiftReg32, ReadShiftReg32, ReadShiftReg32]>; def SRAW : ALUW_rr<0b0100000, 0b101, "sraw">, Sched<[WriteShiftReg32, ReadShiftReg32, ReadShiftReg32]>; +} // IsSignExtendingOpW = 1 } // Predicates = [IsRV64] //===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoD.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoD.td @@ -160,10 +160,12 @@ defm FCLASS_D : FPUnaryOp_r_m<0b1110001, 0b00000, 0b001, XDINX, "fclass.d">, Sched<[WriteFClass64, ReadFClass64]>; +let IsSignExtendingOpW = 1 in defm FCVT_W_D : FPUnaryOp_r_frm_m<0b1100001, 0b00000, XDINX, "fcvt.w.d">, Sched<[WriteFCvtF64ToI32, ReadFCvtF64ToI32]>; defm : FPUnaryOpDynFrmAlias_m; +let IsSignExtendingOpW = 1 in defm FCVT_WU_D : FPUnaryOp_r_frm_m<0b1100001, 0b00001, XDINX, "fcvt.wu.d">, Sched<[WriteFCvtF64ToI32, ReadFCvtF64ToI32]>; defm : FPUnaryOpDynFrmAlias_m; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoF.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoF.td @@ -282,7 +282,8 @@ OpcodeStr, Ext.RdTy, Ext.Rs1Ty>; } -let hasSideEffects = 0, mayLoad = 0, mayStore = 0, mayRaiseFPException = 1 in +let hasSideEffects = 0, mayLoad = 0, mayStore = 0, mayRaiseFPException = 1, + IsSignExtendingOpW = 1 in class FPCmp_rr funct7, bits<3> funct3, string opcodestr, DAGOperand rty, bit Commutable> : RVInstR; } +let IsSignExtendingOpW = 1 in defm FCVT_W_S : FPUnaryOp_r_frm_m<0b1100000, 
0b00000, XFINX, "fcvt.w.s">, Sched<[WriteFCvtF32ToI32, ReadFCvtF32ToI32]>; defm : FPUnaryOpDynFrmAlias_m; +let IsSignExtendingOpW = 1 in defm FCVT_WU_S : FPUnaryOp_r_frm_m<0b1100000, 0b00001, XFINX, "fcvt.wu.s">, Sched<[WriteFCvtF32ToI32, ReadFCvtF32ToI32]>; defm : FPUnaryOpDynFrmAlias_m; -let Predicates = [HasStdExtF], mayRaiseFPException = 0 in +let Predicates = [HasStdExtF], mayRaiseFPException = 0, + IsSignExtendingOpW = 1 in def FMV_X_W : FPUnaryOp_r<0b1110000, 0b00000, 0b000, GPR, FPR32, "fmv.x.w">, Sched<[WriteFMovF32ToI32, ReadFMovF32ToI32]>; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td @@ -46,12 +46,12 @@ Sched<[WriteIDiv, ReadIDiv, ReadIDiv]>; } // Predicates = [HasStdExtM] -let Predicates = [HasStdExtMOrZmmul, IsRV64] in { +let Predicates = [HasStdExtMOrZmmul, IsRV64], IsSignExtendingOpW = 1 in { def MULW : ALUW_rr<0b0000001, 0b000, "mulw", /*Commutable*/1>, Sched<[WriteIMul32, ReadIMul32, ReadIMul32]>; } // Predicates = [HasStdExtMOrZmmul, IsRV64] -let Predicates = [HasStdExtM, IsRV64] in { +let Predicates = [HasStdExtM, IsRV64], IsSignExtendingOpW = 1 in { def DIVW : ALUW_rr<0b0000001, 0b100, "divw">, Sched<[WriteIDiv32, ReadIDiv32, ReadIDiv32]>; def DIVUW : ALUW_rr<0b0000001, 0b101, "divuw">, diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZb.td @@ -349,7 +349,7 @@ Sched<[WriteRotateImm, ReadRotateImm]>; } // Predicates = [HasStdExtZbbOrZbkb] -let Predicates = [HasStdExtZbbOrZbkb, IsRV64] in { +let Predicates = [HasStdExtZbbOrZbkb, IsRV64], IsSignExtendingOpW = 1 in { def ROLW : ALUW_rr<0b0110000, 0b001, "rolw">, Sched<[WriteRotateReg32, ReadRotateReg32, ReadRotateReg32]>; def RORW : ALUW_rr<0b0110000, 0b101, "rorw">, @@ -366,6 +366,7 @@ Sched<[WriteSingleBit, 
ReadSingleBit, ReadSingleBit]>; def BINV : ALU_rr<0b0110100, 0b001, "binv">, Sched<[WriteSingleBit, ReadSingleBit, ReadSingleBit]>; +let IsSignExtendingOpW = 1 in def BEXT : ALU_rr<0b0100100, 0b101, "bext">, Sched<[WriteSingleBit, ReadSingleBit, ReadSingleBit]>; @@ -375,6 +376,7 @@ Sched<[WriteSingleBitImm, ReadSingleBitImm]>; def BINVI : RVBShift_ri<0b01101, 0b001, OPC_OP_IMM, "binvi">, Sched<[WriteSingleBitImm, ReadSingleBitImm]>; +let IsSignExtendingOpW = 1 in def BEXTI : RVBShift_ri<0b01001, 0b101, OPC_OP_IMM, "bexti">, Sched<[WriteSingleBitImm, ReadSingleBitImm]>; } // Predicates = [HasStdExtZbs] @@ -389,7 +391,7 @@ Sched<[WriteXPERM, ReadXPERM, ReadXPERM]>; } // Predicates = [HasStdExtZbkx] -let Predicates = [HasStdExtZbb] in { +let Predicates = [HasStdExtZbb], IsSignExtendingOpW = 1 in { def CLZ : RVBUnary<0b0110000, 0b00000, 0b001, OPC_OP_IMM, "clz">, Sched<[WriteCLZ, ReadCLZ]>; def CTZ : RVBUnary<0b0110000, 0b00001, 0b001, OPC_OP_IMM, "ctz">, @@ -398,7 +400,7 @@ Sched<[WriteCPOP, ReadCPOP]>; } // Predicates = [HasStdExtZbb] -let Predicates = [HasStdExtZbb, IsRV64] in { +let Predicates = [HasStdExtZbb, IsRV64], IsSignExtendingOpW = 1 in { def CLZW : RVBUnary<0b0110000, 0b00000, 0b001, OPC_OP_IMM_32, "clzw">, Sched<[WriteCLZ32, ReadCLZ32]>; def CTZW : RVBUnary<0b0110000, 0b00001, 0b001, OPC_OP_IMM_32, "ctzw">, @@ -407,7 +409,7 @@ Sched<[WriteCPOP32, ReadCPOP32]>; } // Predicates = [HasStdExtZbb, IsRV64] -let Predicates = [HasStdExtZbb] in { +let Predicates = [HasStdExtZbb], IsSignExtendingOpW = 1 in { def SEXT_B : RVBUnary<0b0110000, 0b00100, 0b001, OPC_OP_IMM, "sext.b">, Sched<[WriteIALU, ReadIALU]>; def SEXT_H : RVBUnary<0b0110000, 0b00101, 0b001, OPC_OP_IMM, "sext.h">, @@ -440,11 +442,12 @@ let Predicates = [HasStdExtZbkb] in { def PACK : ALU_rr<0b0000100, 0b100, "pack">, Sched<[WritePACK, ReadPACK, ReadPACK]>; +let IsSignExtendingOpW = 1 in def PACKH : ALU_rr<0b0000100, 0b111, "packh">, Sched<[WritePACK, ReadPACK, ReadPACK]>; } // Predicates = 
[HasStdExtZbkb] -let Predicates = [HasStdExtZbkb, IsRV64] in +let Predicates = [HasStdExtZbkb, IsRV64], IsSignExtendingOpW = 1 in def PACKW : ALUW_rr<0b0000100, 0b100, "packw">, Sched<[WritePACK32, ReadPACK32, ReadPACK32]>; @@ -453,7 +456,7 @@ Sched<[WriteIALU, ReadIALU]>; } // Predicates = [HasStdExtZbb, IsRV32] -let Predicates = [HasStdExtZbb, IsRV64] in { +let Predicates = [HasStdExtZbb, IsRV64], IsSignExtendingOpW = 1 in { def ZEXT_H_RV64 : RVBUnary<0b0000100, 0b00000, 0b100, OPC_OP_32, "zext.h">, Sched<[WriteIALU, ReadIALU]>; } // Predicates = [HasStdExtZbb, IsRV64] diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td @@ -139,10 +139,12 @@ defm FMAX_H : FPALU_rr_m<0b0010110, 0b001, "fmax.h", HINX, /*Commutable*/1>; } +let IsSignExtendingOpW = 1 in defm FCVT_W_H : FPUnaryOp_r_frm_m<0b1100010, 0b00000, XHINX, "fcvt.w.h">, Sched<[WriteFCvtF16ToI32, ReadFCvtF16ToI32]>; defm : FPUnaryOpDynFrmAlias_m; +let IsSignExtendingOpW = 1 in defm FCVT_WU_H : FPUnaryOp_r_frm_m<0b1100010, 0b00001, XHINX, "fcvt.wu.h">, Sched<[WriteFCvtF16ToI32, ReadFCvtF16ToI32]>; defm : FPUnaryOpDynFrmAlias_m; @@ -163,7 +165,7 @@ Sched<[WriteFCvtF16ToF32, ReadFCvtF16ToF32]>; let Predicates = [HasStdExtZfhOrZfhmin] in { -let mayRaiseFPException = 0 in +let mayRaiseFPException = 0, IsSignExtendingOpW = 1 in def FMV_X_H : FPUnaryOp_r<0b1110010, 0b00000, 0b000, GPR, FPR16, "fmv.x.h">, Sched<[WriteFMovF16ToI16, ReadFMovF16ToI16]>; diff --git a/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp b/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp --- a/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp +++ b/llvm/lib/Target/RISCV/RISCVSExtWRemoval.cpp @@ -229,70 +229,15 @@ // This function returns true if the machine instruction always outputs a value // where bits 63:32 match bit 31. -// TODO: Allocate a bit in TSFlags for the W instructions? 
-// TODO: Add other W instructions. static bool isSignExtendingOpW(MachineInstr &MI, MachineRegisterInfo &MRI) { - switch (MI.getOpcode()) { - case RISCV::LUI: - case RISCV::LW: - case RISCV::ADDW: - case RISCV::ADDIW: - case RISCV::SUBW: - case RISCV::MULW: - case RISCV::SLLW: - case RISCV::SLLIW: - case RISCV::SRAW: - case RISCV::SRAIW: - case RISCV::SRLW: - case RISCV::SRLIW: - case RISCV::DIVW: - case RISCV::DIVUW: - case RISCV::REMW: - case RISCV::REMUW: - case RISCV::ROLW: - case RISCV::RORW: - case RISCV::RORIW: - case RISCV::CLZW: - case RISCV::CTZW: - case RISCV::CPOPW: - case RISCV::PACKW: - case RISCV::FCVT_W_H: - case RISCV::FCVT_WU_H: - case RISCV::FCVT_W_S: - case RISCV::FCVT_WU_S: - case RISCV::FCVT_W_D: - case RISCV::FCVT_WU_D: - case RISCV::FMV_X_W: - // The following aren't W instructions, but are either sign extended from a - // smaller size, always outputs a small integer, or put zeros in bits 63:31. - case RISCV::LBU: - case RISCV::LHU: - case RISCV::LB: - case RISCV::LH: - case RISCV::SLT: - case RISCV::SLTI: - case RISCV::SLTU: - case RISCV::SLTIU: - case RISCV::FEQ_H: - case RISCV::FEQ_S: - case RISCV::FEQ_D: - case RISCV::FLT_H: - case RISCV::FLT_S: - case RISCV::FLT_D: - case RISCV::FLE_H: - case RISCV::FLE_S: - case RISCV::FLE_D: - case RISCV::SEXT_B: - case RISCV::SEXT_H: - case RISCV::ZEXT_H_RV64: - case RISCV::FMV_X_H: - case RISCV::BEXT: - case RISCV::BEXTI: - case RISCV::CLZ: - case RISCV::CPOP: - case RISCV::CTZ: - case RISCV::PACKH: + uint64_t TSFlags = MI.getDesc().TSFlags; + + // Instructions that can be determined from opcode are marked in tablegen. + if (TSFlags & RISCVII::IsSignExtendingOpWMask) return true; + + // Special cases that require checking operands. + switch (MI.getOpcode()) { // shifting right sufficiently makes the value 32-bit sign-extended case RISCV::SRAI: return MI.getOperand(2).getImm() >= 32; @@ -310,7 +255,6 @@ // Copying from X0 produces zero. 
case RISCV::COPY: return MI.getOperand(1).getReg() == RISCV::X0; - } return false;