diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.h b/llvm/lib/Target/AMDGPU/SIISelLowering.h
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.h
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.h
@@ -197,6 +197,7 @@
   SDValue reassociateScalarOps(SDNode *N, SelectionDAG &DAG) const;
   unsigned getFusedOpcode(const SelectionDAG &DAG,
                           const SDNode *N0, const SDNode *N1) const;
+  SDValue tryFoldToMad64_32(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performAddCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performAddCarrySubCarryCombine(SDNode *N, DAGCombinerInfo &DCI) const;
   SDValue performSubCombine(SDNode *N, DAGCombinerInfo &DCI) const;
diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
--- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
+++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp
@@ -10661,39 +10661,64 @@
   return DAG.getNode(ISD::TRUNCATE, SL, VT, Mad);
 }
 
-SDValue SITargetLowering::performAddCombine(SDNode *N,
+// Fold (add (mul x, y), z) --> (mad_[iu]64_[iu]32 x, y, z).
+SDValue SITargetLowering::tryFoldToMad64_32(SDNode *N,
                                             DAGCombinerInfo &DCI) const {
+  assert(N->getOpcode() == ISD::ADD);
+
   SelectionDAG &DAG = DCI.DAG;
   EVT VT = N->getValueType(0);
   SDLoc SL(N);
   SDValue LHS = N->getOperand(0);
   SDValue RHS = N->getOperand(1);
 
-  if ((LHS.getOpcode() == ISD::MUL || RHS.getOpcode() == ISD::MUL)
-      && Subtarget->hasMad64_32() &&
-      !VT.isVector() && VT.getScalarSizeInBits() > 32 &&
-      VT.getScalarSizeInBits() <= 64) {
-    if (LHS.getOpcode() != ISD::MUL)
-      std::swap(LHS, RHS);
+  if (VT.isVector())
+    return SDValue();
 
-    SDValue MulLHS = LHS.getOperand(0);
-    SDValue MulRHS = LHS.getOperand(1);
-    SDValue AddRHS = RHS;
+  unsigned NumBits = VT.getScalarSizeInBits();
+  if (NumBits <= 32 || NumBits > 64)
+    return SDValue();
 
-    // TODO: Maybe restrict if SGPR inputs.
-    if (numBitsUnsigned(MulLHS, DAG) <= 32 &&
-        numBitsUnsigned(MulRHS, DAG) <= 32) {
-      MulLHS = DAG.getZExtOrTrunc(MulLHS, SL, MVT::i32);
-      MulRHS = DAG.getZExtOrTrunc(MulRHS, SL, MVT::i32);
-      AddRHS = DAG.getZExtOrTrunc(AddRHS, SL, MVT::i64);
-      return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, false);
-    }
+  if (LHS.getOpcode() != ISD::MUL) {
+    assert(RHS.getOpcode() == ISD::MUL);
+    std::swap(LHS, RHS);
+  }
+
+  SDValue MulLHS = LHS.getOperand(0);
+  SDValue MulRHS = LHS.getOperand(1);
+  SDValue AddRHS = RHS;
+
+  // TODO: Maybe restrict if SGPR inputs.
+  if (numBitsUnsigned(MulLHS, DAG) <= 32 &&
+      numBitsUnsigned(MulRHS, DAG) <= 32) {
+    MulLHS = DAG.getZExtOrTrunc(MulLHS, SL, MVT::i32);
+    MulRHS = DAG.getZExtOrTrunc(MulRHS, SL, MVT::i32);
+    AddRHS = DAG.getZExtOrTrunc(AddRHS, SL, MVT::i64);
+    return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, false);
+  }
+
+  if (numBitsSigned(MulLHS, DAG) <= 32 && numBitsSigned(MulRHS, DAG) <= 32) {
+    MulLHS = DAG.getSExtOrTrunc(MulLHS, SL, MVT::i32);
+    MulRHS = DAG.getSExtOrTrunc(MulRHS, SL, MVT::i32);
+    AddRHS = DAG.getSExtOrTrunc(AddRHS, SL, MVT::i64);
+    return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, true);
+  }
+
+  return SDValue();
+}
+
+SDValue SITargetLowering::performAddCombine(SDNode *N,
+                                            DAGCombinerInfo &DCI) const {
+  SelectionDAG &DAG = DCI.DAG;
+  EVT VT = N->getValueType(0);
+  SDLoc SL(N);
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
 
-    if (numBitsSigned(MulLHS, DAG) <= 32 && numBitsSigned(MulRHS, DAG) <= 32) {
-      MulLHS = DAG.getSExtOrTrunc(MulLHS, SL, MVT::i32);
-      MulRHS = DAG.getSExtOrTrunc(MulRHS, SL, MVT::i32);
-      AddRHS = DAG.getSExtOrTrunc(AddRHS, SL, MVT::i64);
-      return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, true);
+  if (LHS.getOpcode() == ISD::MUL || RHS.getOpcode() == ISD::MUL) {
+    if (Subtarget->hasMad64_32()) {
+      if (SDValue Folded = tryFoldToMad64_32(N, DCI))
+        return Folded;
     }
 
     return SDValue();
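For reference, a minimal C++ sketch (illustrative only, not part of the patch; the function names are hypothetical) of the source-level patterns tryFoldToMad64_32 targets: a 64-bit add of a 64-bit mul whose operands provably fit in 32 bits, here via zero or sign extension. On subtargets where hasMad64_32() returns true, these should fold to a single mad_u64_u32 or mad_i64_i32.

  #include <cstdint>

  // Unsigned pattern: (add (mul (zext x), (zext y)), z). Both multiply
  // operands have 32 known-zero high bits, so numBitsUnsigned(...) <= 32
  // and the combine takes the mad_u64_u32 path.
  uint64_t mad_u64_u32_pattern(uint32_t x, uint32_t y, uint64_t z) {
    return static_cast<uint64_t>(x) * y + z;
  }

  // Signed pattern: (add (mul (sext x), (sext y)), z). Both multiply
  // operands are representable in 32 signed bits, so numBitsSigned(...) <= 32
  // and the combine takes the mad_i64_i32 path.
  int64_t mad_i64_i32_pattern(int32_t x, int32_t y, int64_t z) {
    return static_cast<int64_t>(x) * y + z;
  }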