diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -199,6 +199,13 @@
 
   if (Subtarget.is64Bit() && Subtarget.hasStdExtM()) {
     setOperationAction(ISD::MUL, MVT::i32, Custom);
+
+    setOperationAction(ISD::SDIV, MVT::i8, Custom);
+    setOperationAction(ISD::UDIV, MVT::i8, Custom);
+    setOperationAction(ISD::UREM, MVT::i8, Custom);
+    setOperationAction(ISD::SDIV, MVT::i16, Custom);
+    setOperationAction(ISD::UDIV, MVT::i16, Custom);
+    setOperationAction(ISD::UREM, MVT::i16, Custom);
     setOperationAction(ISD::SDIV, MVT::i32, Custom);
     setOperationAction(ISD::UDIV, MVT::i32, Custom);
     setOperationAction(ISD::UREM, MVT::i32, Custom);
@@ -1433,11 +1440,15 @@
 // be promoted to i64, making it difficult to select the SLLW/DIVUW/.../*W
 // later one because the fact the operation was originally of type i32 is
 // lost.
-static SDValue customLegalizeToWOp(SDNode *N, SelectionDAG &DAG) {
+// ExtOpc is the extension opcode used to widen both operands to i64. ANY_EXTEND
+// is only correct for i32 operations because the *W instructions ignore the
+// upper 32 bits; narrower operands must be sign or zero extended instead.
+static SDValue customLegalizeToWOp(SDNode *N, SelectionDAG &DAG,
+                                   unsigned ExtOpc = ISD::ANY_EXTEND) {
   SDLoc DL(N);
   RISCVISD::NodeType WOpcode = getRISCVWOpcode(N->getOpcode());
-  SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(0));
-  SDValue NewOp1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i64, N->getOperand(1));
+  SDValue NewOp0 = DAG.getNode(ExtOpc, DL, MVT::i64, N->getOperand(0));
+  SDValue NewOp1 = DAG.getNode(ExtOpc, DL, MVT::i64, N->getOperand(1));
   SDValue NewRes = DAG.getNode(WOpcode, DL, MVT::i64, NewOp0, NewOp1);
   // ReplaceNodeResults requires we maintain the same type for the return value.
-  return DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewRes);
+  return DAG.getNode(ISD::TRUNCATE, DL, N->getValueType(0), NewRes);
@@ -1534,14 +1545,24 @@
     break;
   case ISD::SDIV:
   case ISD::UDIV:
-  case ISD::UREM:
-    assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
-           Subtarget.hasStdExtM() && "Unexpected custom legalisation");
+  case ISD::UREM: {
+    MVT VT = N->getSimpleValueType(0);
+    assert((VT == MVT::i8 || VT == MVT::i16 || VT == MVT::i32) &&
+           Subtarget.is64Bit() && Subtarget.hasStdExtM() &&
+           "Unexpected custom legalisation");
     if (N->getOperand(0).getOpcode() == ISD::Constant ||
         N->getOperand(1).getOpcode() == ISD::Constant)
       return;
-    Results.push_back(customLegalizeToWOp(N, DAG));
+
+    // If the input is i32, use ANY_EXTEND since the W instructions don't read
+    // the upper 32 bits. For other types we need to sign or zero extend
+    // based on the opcode.
+    unsigned ExtOpc = VT == MVT::i32 ? ISD::ANY_EXTEND
+                    : N->getOpcode() == ISD::SDIV ? ISD::SIGN_EXTEND
+                                                  : ISD::ZERO_EXTEND;
+    Results.push_back(customLegalizeToWOp(N, DAG, ExtOpc));
     break;
+  }
   case ISD::BITCAST: {
     assert(((N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
              Subtarget.hasStdExtF()) ||
@@ -2142,6 +2163,7 @@
                                                         const APInt &DemandedElts,
                                                         const SelectionDAG &DAG,
                                                         unsigned Depth) const {
+  unsigned BitWidth = Known.getBitWidth();
   unsigned Opc = Op.getOpcode();
   assert((Opc >= ISD::BUILTIN_OP_END ||
           Opc == ISD::INTRINSIC_WO_CHAIN ||
@@ -2153,6 +2175,26 @@
   Known.resetAll();
   switch (Opc) {
   default: break;
+  case RISCVISD::REMUW: {
+    KnownBits Known2;
+    Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
+    Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
+    // We only care about the lower 32 bits.
+    Known = KnownBits::urem(Known.trunc(32), Known2.trunc(32));
+    // Restore the original width by sign extending.
+    Known = Known.sext(BitWidth);
+    break;
+  }
+  case RISCVISD::DIVUW: {
+    KnownBits Known2;
+    Known = DAG.computeKnownBits(Op.getOperand(0), DemandedElts, Depth + 1);
+    Known2 = DAG.computeKnownBits(Op.getOperand(1), DemandedElts, Depth + 1);
+    // We only care about the lower 32 bits.
+    Known = KnownBits::udiv(Known.trunc(32), Known2.trunc(32));
+    // Restore the original width by sign extending.
+    Known = Known.sext(BitWidth);
+    break;
+  }
   case RISCVISD::READ_VLENB:
     // We assume VLENB is at least 8 bytes.
     // FIXME: The 1.0 draft spec defines minimum VLEN as 128 bits.
diff --git a/llvm/test/CodeGen/RISCV/rv64m-exhaustive-w-insts.ll b/llvm/test/CodeGen/RISCV/rv64m-exhaustive-w-insts.ll
--- a/llvm/test/CodeGen/RISCV/rv64m-exhaustive-w-insts.ll
+++ b/llvm/test/CodeGen/RISCV/rv64m-exhaustive-w-insts.ll
@@ -529,7 +529,7 @@
 define zeroext i8 @zext_divuw_zext_zext_i8(i8 zeroext %a, i8 zeroext %b) nounwind {
 ; RV64IM-LABEL: zext_divuw_zext_zext_i8:
 ; RV64IM:       # %bb.0:
-; RV64IM-NEXT:    divu a0, a0, a1
+; RV64IM-NEXT:    divuw a0, a0, a1
 ; RV64IM-NEXT:    ret
   %1 = udiv i8 %a, %b
   ret i8 %1
@@ -538,7 +538,7 @@
 define zeroext i16 @zext_divuw_zext_zext_i16(i16 zeroext %a, i16 zeroext %b) nounwind {
 ; RV64IM-LABEL: zext_divuw_zext_zext_i16:
 ; RV64IM:       # %bb.0:
-; RV64IM-NEXT:    divu a0, a0, a1
+; RV64IM-NEXT:    divuw a0, a0, a1
 ; RV64IM-NEXT:    ret
   %1 = udiv i16 %a, %b
   ret i16 %1
@@ -808,9 +808,7 @@
 define signext i8 @sext_divw_sext_sext_i8(i8 signext %a, i8 signext %b) nounwind {
 ; RV64IM-LABEL: sext_divw_sext_sext_i8:
 ; RV64IM:       # %bb.0:
-; RV64IM-NEXT:    div a0, a0, a1
-; RV64IM-NEXT:    slli a0, a0, 56
-; RV64IM-NEXT:    srai a0, a0, 56
+; RV64IM-NEXT:    divw a0, a0, a1
 ; RV64IM-NEXT:    ret
   %1 = sdiv i8 %a, %b
   ret i8 %1
@@ -819,9 +817,7 @@
 define signext i16 @sext_divw_sext_sext_i16(i16 signext %a, i16 signext %b) nounwind {
 ; RV64IM-LABEL: sext_divw_sext_sext_i16:
 ; RV64IM:       # %bb.0:
-; RV64IM-NEXT:    div a0, a0, a1
-; RV64IM-NEXT:    slli a0, a0, 48
-; RV64IM-NEXT:    srai a0, a0, 48
+; RV64IM-NEXT:    divw a0, a0, a1
 ; RV64IM-NEXT:    ret
   %1 = sdiv i16 %a, %b
   ret i16 %1
@@ -1372,7 +1368,7 @@
 define zeroext i8 @zext_remuw_zext_zext_i8(i8 zeroext %a, i8 zeroext %b) nounwind {
 ; RV64IM-LABEL: zext_remuw_zext_zext_i8:
 ; RV64IM:       # %bb.0:
-; RV64IM-NEXT:    remu a0, a0, a1
+; RV64IM-NEXT:    remuw a0, a0, a1
 ; RV64IM-NEXT:    ret
   %1 = urem i8 %a, %b
   ret i8 %1
@@ -1381,7 +1377,7 @@
 define zeroext i16 @zext_remuw_zext_zext_i16(i16 zeroext %a, i16 zeroext %b) nounwind {
 ; RV64IM-LABEL: zext_remuw_zext_zext_i16:
 ; RV64IM:       # %bb.0:
-; RV64IM-NEXT:    remu a0, a0, a1
+; RV64IM-NEXT:    remuw a0, a0, a1
 ; RV64IM-NEXT:    ret
   %1 = urem i16 %a, %b
   ret i16 %1