diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -40,6 +40,8 @@
   BuildPairF64,
   SplitF64,
   TAIL,
+  // Multiply high for signed x unsigned operands (selected to MULHSU).
+  MULHSU,
   // RV64I shifts, directly matching the semantics of the named RISC-V
   // instructions.
   SLLW,
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -219,20 +219,23 @@
     setOperationAction(ISD::UDIV, XLenVT, Expand);
     setOperationAction(ISD::SREM, XLenVT, Expand);
     setOperationAction(ISD::UREM, XLenVT, Expand);
-  }
-
-  if (Subtarget.is64Bit() && Subtarget.hasStdExtM()) {
-    setOperationAction(ISD::MUL, MVT::i32, Custom);
-
-    setOperationAction(ISD::SDIV, MVT::i8, Custom);
-    setOperationAction(ISD::UDIV, MVT::i8, Custom);
-    setOperationAction(ISD::UREM, MVT::i8, Custom);
-    setOperationAction(ISD::SDIV, MVT::i16, Custom);
-    setOperationAction(ISD::UDIV, MVT::i16, Custom);
-    setOperationAction(ISD::UREM, MVT::i16, Custom);
-    setOperationAction(ISD::SDIV, MVT::i32, Custom);
-    setOperationAction(ISD::UDIV, MVT::i32, Custom);
-    setOperationAction(ISD::UREM, MVT::i32, Custom);
+  } else {
+    if (Subtarget.is64Bit()) {
+      setOperationAction(ISD::MUL, MVT::i32, Custom);
+      setOperationAction(ISD::MUL, MVT::i128, Custom);
+
+      setOperationAction(ISD::SDIV, MVT::i8, Custom);
+      setOperationAction(ISD::UDIV, MVT::i8, Custom);
+      setOperationAction(ISD::UREM, MVT::i8, Custom);
+      setOperationAction(ISD::SDIV, MVT::i16, Custom);
+      setOperationAction(ISD::UDIV, MVT::i16, Custom);
+      setOperationAction(ISD::UREM, MVT::i16, Custom);
+      setOperationAction(ISD::SDIV, MVT::i32, Custom);
+      setOperationAction(ISD::UDIV, MVT::i32, Custom);
+      setOperationAction(ISD::UREM, MVT::i32, Custom);
+    } else {
+      setOperationAction(ISD::MUL, MVT::i64, Custom);
+    }
   }
 
setOperationAction(ISD::SDIVREM, XLenVT, Expand);
@@ -3925,9 +3928,43 @@
     Results.push_back(RCW.getValue(2));
     break;
   }
+  case ISD::MUL: {
+    unsigned Size = N->getSimpleValueType(0).getSizeInBits();
+    unsigned XLen = Subtarget.getXLen();
+    // Wider than XLen: this multiply must be expanded, so try MULHSU+MUL.
+    if (Size > XLen) {
+      assert(Size == (XLen * 2) && "Unexpected custom legalisation");
+      SDValue LHS = N->getOperand(0);
+      SDValue RHS = N->getOperand(1);
+      APInt HighMask = APInt::getHighBitsSet(Size, XLen);
+      // "Unsigned" here means the upper XLen bits are known to be zero.
+      bool LHSIsU = DAG.MaskedValueIsZero(LHS, HighMask);
+      bool RHSIsU = DAG.MaskedValueIsZero(RHS, HighMask);
+      // We need exactly one side to be unsigned; otherwise add no results.
+      if (LHSIsU == RHSIsU)
+        return;
+      // Lo = MUL and Hi = MULHSU of the XLen-truncated operands S and U.
+      auto MakeMULPair = [&](SDValue S, SDValue U) {
+        MVT XLenVT = Subtarget.getXLenVT();
+        S = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, S);
+        U = DAG.getNode(ISD::TRUNCATE, DL, XLenVT, U);
+        SDValue Lo = DAG.getNode(ISD::MUL, DL, XLenVT, S, U);
+        SDValue Hi = DAG.getNode(RISCVISD::MULHSU, DL, XLenVT, S, U);
+        return DAG.getNode(ISD::BUILD_PAIR, DL, N->getValueType(0), Lo, Hi);
+      };
+
+      // The other operand must be sign-extended (more than XLen sign bits).
+      if (RHSIsU && DAG.ComputeNumSignBits(LHS) > XLen)
+        Results.push_back(MakeMULPair(LHS, RHS));
+      else if (LHSIsU && DAG.ComputeNumSignBits(RHS) > XLen)
+        Results.push_back(MakeMULPair(RHS, LHS));
+
+      return;
+    }
+    LLVM_FALLTHROUGH;
+  }
   case ISD::ADD:
   case ISD::SUB:
-  case ISD::MUL:
     assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
            "Unexpected custom legalisation");
     if (N->getOperand(1).getOpcode() == ISD::Constant)
@@ -6799,6 +6836,7 @@
   NODE_NAME_CASE(BuildPairF64)
   NODE_NAME_CASE(SplitF64)
   NODE_NAME_CASE(TAIL)
+  NODE_NAME_CASE(MULHSU)
   NODE_NAME_CASE(SLLW)
   NODE_NAME_CASE(SRAW)
   NODE_NAME_CASE(SRLW)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoM.td
@@ -15,6 +15,7 @@
 // RISC-V specific DAG Nodes.
//===----------------------------------------------------------------------===//
 
+def riscv_mulhsu : SDNode<"RISCVISD::MULHSU", SDTIntBinOp>;
 def riscv_divw : SDNode<"RISCVISD::DIVW", SDT_RISCVIntBinOpW>;
 def riscv_divuw : SDNode<"RISCVISD::DIVUW", SDT_RISCVIntBinOpW>;
 def riscv_remuw : SDNode<"RISCVISD::REMUW", SDT_RISCVIntBinOpW>;
@@ -63,7 +64,7 @@
 def : PatGprGpr<mul, MUL>;
 def : PatGprGpr<mulhs, MULH>;
 def : PatGprGpr<mulhu, MULHU>;
-// No ISDOpcode for mulhsu
+def : PatGprGpr<riscv_mulhsu, MULHSU>;
 def : PatGprGpr<sdiv, DIV>;
 def : PatGprGpr<udiv, DIVU>;
 def : PatGprGpr<srem, REM>;
diff --git a/llvm/test/CodeGen/RISCV/mul.ll b/llvm/test/CodeGen/RISCV/mul.ll
--- a/llvm/test/CodeGen/RISCV/mul.ll
+++ b/llvm/test/CodeGen/RISCV/mul.ll
@@ -309,10 +309,7 @@
 ;
 ; RV32IM-LABEL: mulhsu:
 ; RV32IM:       # %bb.0:
-; RV32IM-NEXT:    srai a2, a1, 31
-; RV32IM-NEXT:    mulhu a1, a0, a1
-; RV32IM-NEXT:    mul a0, a0, a2
-; RV32IM-NEXT:    add a0, a1, a0
+; RV32IM-NEXT:    mulhsu a0, a1, a0
 ; RV32IM-NEXT:    ret
 ;
 ; RV64I-LABEL: mulhsu:
@@ -1294,10 +1291,7 @@
 ;
 ; RV64IM-LABEL: mulhsu_i64:
 ; RV64IM:       # %bb.0:
-; RV64IM-NEXT:    srai a2, a1, 63
-; RV64IM-NEXT:    mulhu a1, a0, a1
-; RV64IM-NEXT:    mul a0, a0, a2
-; RV64IM-NEXT:    add a0, a1, a0
+; RV64IM-NEXT:    mulhsu a0, a1, a0
 ; RV64IM-NEXT:    ret
   %1 = zext i64 %a to i128
   %2 = sext i64 %b to i128