diff --git a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp --- a/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp @@ -148,8 +148,14 @@ /// \returns True. bool promoteUniformBitreverseToI32(IntrinsicInst &I) const; - + /// \returns The minimum number of bits needed to store the value of \p Op as an + /// unsigned integer. Truncating to this size and then zero-extending to + /// \p ScalarSize will not change the value. unsigned numBitsUnsigned(Value *Op, unsigned ScalarSize) const; + + /// \returns The minimum number of bits needed to store the value of \p Op as a + /// signed integer. Truncating to this size and then sign-extending to + /// \p ScalarSize will not change the value. unsigned numBitsSigned(Value *Op, unsigned ScalarSize) const; /// Replace mul instructions with llvm.amdgcn.mul.u24 or llvm.amdgcn.mul.s24. @@ -449,7 +455,7 @@ unsigned ScalarSize) const { // In order for this to be a signed 24-bit value, bit 23, must // be a sign bit. - return ScalarSize - ComputeNumSignBits(Op, *DL, 0, AC); + return ScalarSize - ComputeNumSignBits(Op, *DL, 0, AC) + 1; } static void extractValues(IRBuilder<> &Builder, @@ -482,13 +488,13 @@ // width of the original destination. static Value *getMul24(IRBuilder<> &Builder, Value *LHS, Value *RHS, unsigned Size, unsigned NumBits, bool IsSigned) { - if (Size <= 32 || (IsSigned ? NumBits <= 30 : NumBits <= 32)) { + if (Size <= 32 || NumBits <= 32) { Intrinsic::ID ID = IsSigned ? Intrinsic::amdgcn_mul_i24 : Intrinsic::amdgcn_mul_u24; return Builder.CreateIntrinsic(ID, {}, {LHS, RHS}); } - assert(IsSigned ? NumBits <= 46 : NumBits <= 48); + assert(NumBits <= 48); Intrinsic::ID LoID = IsSigned ? 
Intrinsic::amdgcn_mul_i24 : Intrinsic::amdgcn_mul_u24; @@ -530,9 +536,8 @@ (RHSBits = numBitsUnsigned(RHS, Size)) <= 24) { IsSigned = false; - } else if (ST->hasMulI24() && - (LHSBits = numBitsSigned(LHS, Size)) < 24 && - (RHSBits = numBitsSigned(RHS, Size)) < 24) { + } else if (ST->hasMulI24() && (LHSBits = numBitsSigned(LHS, Size)) <= 24 && + (RHSBits = numBitsSigned(RHS, Size)) <= 24) { IsSigned = true; } else diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.h @@ -35,7 +35,14 @@ SDValue getFFBX_U32(SelectionDAG &DAG, SDValue Op, const SDLoc &DL, unsigned Opc) const; public: + /// \returns The minimum number of bits needed to store the value of \p Op as an + /// unsigned integer. Truncating to this size and then zero-extending to the + /// original size will not change the value. static unsigned numBitsUnsigned(SDValue Op, SelectionDAG &DAG); + + /// \returns The minimum number of bits needed to store the value of \p Op as a + /// signed integer. Truncating to this size and then sign-extending to the + /// original size will not change the value. static unsigned numBitsSigned(SDValue Op, SelectionDAG &DAG); protected: diff --git a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp --- a/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp @@ -53,7 +53,7 @@ // In order for this to be a signed 24-bit value, bit 23, must // be a sign bit. - return VT.getSizeInBits() - DAG.ComputeNumSignBits(Op); + return VT.getSizeInBits() - DAG.ComputeNumSignBits(Op) + 1; } AMDGPUTargetLowering::AMDGPUTargetLowering(const TargetMachine &TM, @@ -2875,7 +2875,7 @@ EVT VT = Op.getValueType(); return VT.getSizeInBits() >= 24 && // Types less than 24-bit should be treated // as unsigned 24-bit values. 
- AMDGPUTargetLowering::numBitsSigned(Op, DAG) < 24; + AMDGPUTargetLowering::numBitsSigned(Op, DAG) <= 24; } static SDValue simplifyMul24(SDNode *Node24, diff --git a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp --- a/llvm/lib/Target/AMDGPU/SIISelLowering.cpp +++ b/llvm/lib/Target/AMDGPU/SIISelLowering.cpp @@ -10464,7 +10464,7 @@ return getMad64_32(DAG, SL, VT, MulLHS, MulRHS, AddRHS, false); } - if (numBitsSigned(MulLHS, DAG) < 32 && numBitsSigned(MulRHS, DAG) < 32) { + if (numBitsSigned(MulLHS, DAG) <= 32 && numBitsSigned(MulRHS, DAG) <= 32) { MulLHS = DAG.getSExtOrTrunc(MulLHS, SL, MVT::i32); MulRHS = DAG.getSExtOrTrunc(MulRHS, SL, MVT::i32); AddRHS = DAG.getSExtOrTrunc(AddRHS, SL, MVT::i64);