Index: llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
===================================================================
--- llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
+++ llvm/include/llvm/Transforms/InstCombine/InstCombiner.h
@@ -480,6 +480,15 @@
     return llvm::ComputeNumSignBits(Op, DL, Depth, &AC, CxtI, &DT);
   }
 
+  /// Get the minimum bit size for this Value \p Op as a signed integer.
+  /// i.e. the smallest bit width N such that truncating \p Op to N bits and
+  /// sign-extending it back reproduces the original value.
+  unsigned ComputeMinSignedBits(const Value *Op, unsigned Depth = 0,
+                                const Instruction *CxtI = nullptr) const {
+    unsigned SignedBits = ComputeNumSignBits(Op, Depth, CxtI);
+    return Op->getType()->getScalarSizeInBits() - SignedBits + 1;
+  }
+
   OverflowResult computeOverflowForUnsignedMul(const Value *LHS,
                                                const Value *RHS,
                                                const Instruction *CxtI) const {
Index: llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1270,9 +1270,8 @@
   // This is only really a signed overflow check if the inputs have been
   // sign-extended; check for that condition. For example, if CI2 is 2^31 and
   // the operands of the add are 64 bits wide, we need at least 33 sign bits.
-  unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1;
-  if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits ||
-      IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits)
+  if (IC.ComputeMinSignedBits(A, 0, &I) > NewWidth ||
+      IC.ComputeMinSignedBits(B, 0, &I) > NewWidth)
     return nullptr;
 
   // In order to replace the original add with a narrower
Index: llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2304,16 +2304,6 @@
   // Create the new type (which can be a vector type)
   Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth);
 
-  // Match the two extends from the add/sub
-  Value *A, *B;
-  if(!match(AddSub, m_BinOp(m_SExt(m_Value(A)), m_SExt(m_Value(B)))))
-    return nullptr;
-  // And check the incoming values are of a type smaller than or equal to the
-  // size of the saturation. Otherwise the higher bits can cause different
-  // results.
-  if (A->getType()->getScalarSizeInBits() > NewBitWidth ||
-      B->getType()->getScalarSizeInBits() > NewBitWidth)
-    return nullptr;
 
   Intrinsic::ID IntrinsicID;
   if (AddSub->getOpcode() == Instruction::Add)
@@ -2323,10 +2313,19 @@
   else
     return nullptr;
+  // The two operands of the add/sub must be nsw-truncatable to the NewTy. This
+  // is usually achieved via a sext from a smaller type.
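+  // For example, when saturating to i32 (NewBitWidth == 32), an operand of
+  // the form (sext i32 %x to i64) has at least 33 sign bits, so its minimum
+  // signed bit width is at most 32 and the trunc below loses no signed
+  // information. (%x here is purely illustrative.)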
+  if (ComputeMinSignedBits(AddSub->getOperand(0), 0, AddSub) > NewBitWidth ||
+      ComputeMinSignedBits(AddSub->getOperand(1), 0, AddSub) > NewBitWidth)
+    return nullptr;
+
   // Finally create and return the sat intrinsic, truncated to the new type
   Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID,
                                           NewTy);
-  Value *AT = Builder.CreateSExt(A, NewTy);
-  Value *BT = Builder.CreateSExt(B, NewTy);
+  Value *AT = Builder.CreateTrunc(AddSub->getOperand(0), NewTy);
+  Value *BT = Builder.CreateTrunc(AddSub->getOperand(1), NewTy);
   Value *Sat = Builder.CreateCall(F, {AT, BT});
   return CastInst::Create(Instruction::SExt, Sat, Ty);
 }
Index: llvm/test/Transforms/InstCombine/sadd_sat.ll
===================================================================
--- llvm/test/Transforms/InstCombine/sadd_sat.ll
+++ llvm/test/Transforms/InstCombine/sadd_sat.ll
@@ -698,13 +698,10 @@
 define i32 @ashrA(i64 %a, i32 %b) {
 ; CHECK-LABEL: @ashrA(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CONV:%.*]] = ashr i64 [[A:%.*]], 32
-; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT:    [[SPEC_STORE_SELECT:%.*]] = call i64 @llvm.smin.i64(i64 [[ADD]], i64 2147483647)
-; CHECK-NEXT:    [[SPEC_STORE_SELECT8:%.*]] = call i64 @llvm.smax.i64(i64 [[SPEC_STORE_SELECT]], i64 -2147483648)
-; CHECK-NEXT:    [[CONV7:%.*]] = trunc i64 [[SPEC_STORE_SELECT8]] to i32
-; CHECK-NEXT:    ret i32 [[CONV7]]
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i64 [[A:%.*]], 32
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[B:%.*]])
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
 entry:
   %conv = ashr i64 %a, 32
@@ -719,15 +716,10 @@
 define i32 @ashrB(i32 %a, i64 %b) {
 ; CHECK-LABEL: @ashrB(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CONV:%.*]] = sext i32 [[A:%.*]] to i64
-; CHECK-NEXT:    [[CONV1:%.*]] = ashr i64 [[B:%.*]], 32
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[CONV1]], [[CONV]]
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt i64 [[ADD]], -2147483648
-; CHECK-NEXT:    [[SPEC_STORE_SELECT:%.*]] = select i1 [[TMP0]], i64 [[ADD]], i64 -2147483648
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i64 [[SPEC_STORE_SELECT]], 2147483647
-; CHECK-NEXT:    [[SPEC_STORE_SELECT8:%.*]] = select i1 [[TMP1]], i64 [[SPEC_STORE_SELECT]], i64 2147483647
-; CHECK-NEXT:    [[CONV7:%.*]] = trunc i64 [[SPEC_STORE_SELECT8]] to i32
-; CHECK-NEXT:    ret i32 [[CONV7]]
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i64 [[B:%.*]], 32
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i64 [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP2:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP1]], i32 [[A:%.*]])
+; CHECK-NEXT:    ret i32 [[TMP2]]
 ;
 entry:
   %conv = sext i32 %a to i64
@@ -744,15 +736,12 @@
 define i32 @ashrAB(i64 %a, i64 %b) {
 ; CHECK-LABEL: @ashrAB(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CONV:%.*]] = ashr i64 [[A:%.*]], 32
-; CHECK-NEXT:    [[CONV1:%.*]] = ashr i64 [[B:%.*]], 32
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[CONV1]], [[CONV]]
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt i64 [[ADD]], -2147483648
-; CHECK-NEXT:    [[SPEC_STORE_SELECT:%.*]] = select i1 [[TMP0]], i64 [[ADD]], i64 -2147483648
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i64 [[SPEC_STORE_SELECT]], 2147483647
-; CHECK-NEXT:    [[SPEC_STORE_SELECT8:%.*]] = select i1 [[TMP1]], i64 [[SPEC_STORE_SELECT]], i64 2147483647
-; CHECK-NEXT:    [[CONV7:%.*]] = trunc i64 [[SPEC_STORE_SELECT8]] to i32
-; CHECK-NEXT:    ret i32 [[CONV7]]
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i64 [[A:%.*]], 32
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[B:%.*]], 32
+; CHECK-NEXT:    [[TMP2:%.*]] = trunc i64 [[TMP1]] to i32
+; CHECK-NEXT:    [[TMP3:%.*]] = trunc i64 [[TMP0]] to i32
+; CHECK-NEXT:    [[TMP4:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP2]], i32 [[TMP3]])
+; CHECK-NEXT:    ret i32 [[TMP4]]
 ;
 entry:
   %conv = ashr i64 %a, 32
@@ -795,14 +784,9 @@
 ; CHECK-LABEL: @ashrA33(
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    [[CONV:%.*]] = ashr i64 [[A:%.*]], 33
-; CHECK-NEXT:    [[CONV1:%.*]] = sext i32 [[B:%.*]] to i64
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw i64 [[CONV]], [[CONV1]]
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt i64 [[ADD]], -2147483648
-; CHECK-NEXT:    [[SPEC_STORE_SELECT:%.*]] = select i1 [[TMP0]], i64 [[ADD]], i64 -2147483648
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt i64 [[SPEC_STORE_SELECT]], 2147483647
-; CHECK-NEXT:    [[SPEC_STORE_SELECT8:%.*]] = select i1 [[TMP1]], i64 [[SPEC_STORE_SELECT]], i64 2147483647
-; CHECK-NEXT:    [[CONV7:%.*]] = trunc i64 [[SPEC_STORE_SELECT8]] to i32
-; CHECK-NEXT:    ret i32 [[CONV7]]
+; CHECK-NEXT:    [[TMP0:%.*]] = trunc i64 [[CONV]] to i32
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[TMP0]], i32 [[B:%.*]])
+; CHECK-NEXT:    ret i32 [[TMP1]]
 ;
 entry:
   %conv = ashr i64 %a, 33
@@ -844,15 +828,10 @@
 define <2 x i8> @ashrv2i8_s(<2 x i16> %a, <2 x i8> %b) {
 ; CHECK-LABEL: @ashrv2i8_s(
 ; CHECK-NEXT:  entry:
-; CHECK-NEXT:    [[CONV:%.*]] = ashr <2 x i16> [[A:%.*]], <i16 8, i16 8>
-; CHECK-NEXT:    [[CONV1:%.*]] = sext <2 x i8> [[B:%.*]] to <2 x i16>
-; CHECK-NEXT:    [[ADD:%.*]] = add nsw <2 x i16> [[CONV]], [[CONV1]]
-; CHECK-NEXT:    [[TMP0:%.*]] = icmp sgt <2 x i16> [[ADD]], <i16 -128, i16 -128>
-; CHECK-NEXT:    [[SPEC_STORE_SELECT:%.*]] = select <2 x i1> [[TMP0]], <2 x i16> [[ADD]], <2 x i16> <i16 -128, i16 -128>
-; CHECK-NEXT:    [[TMP1:%.*]] = icmp slt <2 x i16> [[SPEC_STORE_SELECT]], <i16 127, i16 127>
-; CHECK-NEXT:    [[SPEC_STORE_SELECT8:%.*]] = select <2 x i1> [[TMP1]], <2 x i16> [[SPEC_STORE_SELECT]], <2 x i16> <i16 127, i16 127>
-; CHECK-NEXT:    [[CONV7:%.*]] = trunc <2 x i16> [[SPEC_STORE_SELECT8]] to <2 x i8>
-; CHECK-NEXT:    ret <2 x i8> [[CONV7]]
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr <2 x i16> [[A:%.*]], <i16 8, i16 8>
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc <2 x i16> [[TMP0]] to <2 x i8>
+; CHECK-NEXT:    [[TMP2:%.*]] = call <2 x i8> @llvm.sadd.sat.v2i8(<2 x i8> [[TMP1]], <2 x i8> [[B:%.*]])
+; CHECK-NEXT:    ret <2 x i8> [[TMP2]]
 ;
 entry:
   %conv = ashr <2 x i16> %a, <i16 8, i16 8>
@@ -868,13 +847,10 @@
 
 define i16 @or(i8 %X, i16 %Y) {
 ; CHECK-LABEL: @or(
-; CHECK-NEXT:    [[CONV10:%.*]] = sext i8 [[X:%.*]] to i16
-; CHECK-NEXT:    [[CONV14:%.*]] = or i16 [[Y:%.*]], -16
-; CHECK-NEXT:    [[SUB:%.*]] = sub nsw i16 [[CONV10]], [[CONV14]]
-; CHECK-NEXT:    [[L9:%.*]] = icmp sgt i16 [[SUB]], -128
-; CHECK-NEXT:    [[L10:%.*]] = select i1 [[L9]], i16 [[SUB]], i16 -128
-; CHECK-NEXT:    [[L11:%.*]] = icmp slt i16 [[L10]], 127
-; CHECK-NEXT:    [[L12:%.*]] = select i1 [[L11]], i16 [[L10]], i16 127
+; CHECK-NEXT:    [[TMP1:%.*]] = trunc i16 [[Y:%.*]] to i8
+; CHECK-NEXT:    [[TMP2:%.*]] = or i8 [[TMP1]], -16
+; CHECK-NEXT:    [[TMP3:%.*]] = call i8 @llvm.ssub.sat.i8(i8 [[X:%.*]], i8 [[TMP2]])
+; CHECK-NEXT:    [[L12:%.*]] = sext i8 [[TMP3]] to i16
 ; CHECK-NEXT:    ret i16 [[L12]]
 ;
   %conv10 = sext i8 %X to i16