diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -4450,6 +4450,11 @@
       return DAG.getConstant(0, DL, VT);
   }
 
+  // fold (mulhs c1, c2) -> c3
+  // Taken only when a non-null node comes back; otherwise fall through.
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHS, DL, VT, {N0, N1}))
+    return C;
+
   // fold (mulhs x, 0) -> 0
   if (isNullConstant(N1))
     return N1;
@@ -4498,6 +4503,11 @@
       return DAG.getConstant(0, DL, VT);
   }
 
+  // fold (mulhu c1, c2) -> c3
+  // Taken only when a non-null node comes back; otherwise fall through.
+  if (SDValue C = DAG.FoldConstantArithmetic(ISD::MULHU, DL, VT, {N0, N1}))
+    return C;
+
   // fold (mulhu x, 0) -> 0
   if (isNullConstant(N1))
     return N1;
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
@@ -5085,6 +5085,22 @@
     if (!C2.getBoolValue())
       break;
     return C1.srem(C2);
+  case ISD::MULHS: {
+    // High half of the signed product: sign-extend both operands to twice
+    // C1's width, multiply, and keep bits [Width, 2 * Width).
+    unsigned FullWidth = C1.getBitWidth() * 2;
+    APInt C1Ext = C1.sext(FullWidth);
+    APInt C2Ext = C2.sext(FullWidth);
+    return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+  }
+  case ISD::MULHU: {
+    // High half of the unsigned product: zero-extend both operands to twice
+    // C1's width, multiply, and keep bits [Width, 2 * Width).
+    unsigned FullWidth = C1.getBitWidth() * 2;
+    APInt C1Ext = C1.zext(FullWidth);
+    APInt C2Ext = C2.zext(FullWidth);
+    return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
+  }
   }
   return llvm::None;
 }