diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -3330,20 +3330,21 @@ /// Fold /// (-1 u/ x) u< y +/// ((x * y) u/ x) != y /// to /// @llvm.umul.with.overflow(x, y) plus extraction of overflow bit -/// Note that the comparison is commutative, while inverted (u>=) predicate +/// Note that the comparison is commutative, while inverted (u>=, ==) predicate /// will mean that we are looking for the opposite answer. -static Value * -foldUnsignedMultiplicationOverflowCheck(ICmpInst &I, - InstCombiner::BuilderTy &Builder) { +Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { ICmpInst::Predicate Pred; Value *X, *Y; + Instruction *Mul; bool NeedNegation; // Look for: (-1 u/ x) u= y if (!I.isEquality() && match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))), m_Value(Y)))) { + Mul = nullptr; // Canonicalize as-if y was on RHS. if (I.getOperand(1) != Y) Pred = I.getSwappedPredicate(); @@ -3359,12 +3360,33 @@ default: return nullptr; // Wrong predicate. } + } else // Look for: ((x * y) u/ x) !=/== y + if (I.isEquality() && + match(&I, m_c_ICmp(Pred, m_Value(Y), + m_OneUse(m_UDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y), + m_Value(X)), + m_Instruction(Mul)), + m_Deferred(X)))))) { + NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ; } else return nullptr; + BuilderTy::InsertPointGuard Guard(Builder); + // If the pattern included (x * y), we'll want to insert new instructions + // right before that original multiplication so that we can replace it. + if (Mul) + Builder.SetInsertPoint(Mul); + Function *F = Intrinsic::getDeclaration( I.getModule(), Intrinsic::umul_with_overflow, X->getType()); CallInst *Call = Builder.CreateCall(F, {X, Y}, "umul"); + + // If the multiplication was used elsewhere, to ensure that we don't leave + // "duplicate" instructions, replace uses of that original multiplication + // with the multiplication result from the with.overflow intrinsic. + if (Mul && !Mul->hasOneUse()) + replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "umul.val")); + Value *Res = Builder.CreateExtractValue(Call, 1, "umul.ov"); if (NeedNegation) Res = Builder.CreateNot(Res, "umul.not.ov"); @@ -3721,7 +3743,7 @@ } } - if (Value *V = foldUnsignedMultiplicationOverflowCheck(I, Builder)) + if (Value *V = foldUnsignedMultiplicationOverflowCheck(I)) return replaceInstUsesWith(I, V); if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder)) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -841,6 +841,8 @@ Instruction *foldICmpEquality(ICmpInst &Cmp); Instruction *foldICmpWithZero(ICmpInst &Cmp); + Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp); + Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, ConstantInt *C); Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc, diff --git a/llvm/test/Transforms/InstCombine/unsigned-mul-lack-of-overflow-check-via-mul-udiv.ll b/llvm/test/Transforms/InstCombine/unsigned-mul-lack-of-overflow-check-via-mul-udiv.ll --- a/llvm/test/Transforms/InstCombine/unsigned-mul-lack-of-overflow-check-via-mul-udiv.ll +++ b/llvm/test/Transforms/InstCombine/unsigned-mul-lack-of-overflow-check-via-mul-udiv.ll @@ -8,10 +8,10 @@ define i1 @t0_basic(i8 %x, i8 %y) { ; CHECK-LABEL: @t0_basic( -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[T1:%.*]] = udiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[UMUL:%.*]] = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]]) +; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i8, i1 } [[UMUL]], 1 +; CHECK-NEXT: [[UMUL_NOT_OV:%.*]] = xor i1 [[UMUL_OV]], true +; CHECK-NEXT: ret i1 [[UMUL_NOT_OV]] ; %t0 = mul i8 %x, %y %t1 = udiv i8 %t0, %x @@ -21,10 +21,10 @@ define <2 x i1> @t1_vec(<2 x i8> %x, <2 x i8> %y) { ; CHECK-LABEL: @t1_vec( -; CHECK-NEXT: [[T0:%.*]] = mul <2 x i8> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[T1:%.*]] = udiv <2 x i8> [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq <2 x i8> [[T1]], [[Y]] -; CHECK-NEXT: ret <2 x i1> [[R]] +; CHECK-NEXT: [[UMUL:%.*]] = call { <2 x i8>, <2 x i1> } @llvm.umul.with.overflow.v2i8(<2 x i8> [[X:%.*]], <2 x i8> [[Y:%.*]]) +; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { <2 x i8>, <2 x i1> } [[UMUL]], 1 +; CHECK-NEXT: [[UMUL_NOT_OV:%.*]] = xor <2 x i1> [[UMUL_OV]], +; CHECK-NEXT: ret <2 x i1> [[UMUL_NOT_OV]] ; %t0 = mul <2 x i8> %x, %y %t1 = udiv <2 x i8> %t0, %x @@ -37,10 +37,10 @@ define i1 @t2_commutative(i8 %x) { ; CHECK-LABEL: @t2_commutative( ; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8() -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]] -; CHECK-NEXT: [[T1:%.*]] = udiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[UMUL:%.*]] = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]]) +; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i8, i1 } [[UMUL]], 1 +; CHECK-NEXT: [[UMUL_NOT_OV:%.*]] = xor i1 [[UMUL_OV]], true +; CHECK-NEXT: ret i1 [[UMUL_NOT_OV]] ; %y = call i8 @gen8() %t0 = mul i8 %y, %x ; swapped @@ -52,10 +52,10 @@ define i1 @t3_commutative(i8 %x) { ; CHECK-LABEL: @t3_commutative( ; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8() -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]] -; CHECK-NEXT: [[T1:%.*]] = udiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[UMUL:%.*]] = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]]) +; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i8, i1 } [[UMUL]], 1 +; CHECK-NEXT: [[UMUL_NOT_OV:%.*]] = xor i1 [[UMUL_OV]], true +; CHECK-NEXT: ret i1 [[UMUL_NOT_OV]] ; %y = call i8 @gen8() %t0 = mul i8 %y, %x ; swapped @@ -67,10 +67,10 @@ define i1 @t4_commutative(i8 %x) { ; CHECK-LABEL: @t4_commutative( ; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8() -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]] -; CHECK-NEXT: [[T1:%.*]] = udiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[Y]], [[T1]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[UMUL:%.*]] = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]]) +; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i8, i1 } [[UMUL]], 1 +; CHECK-NEXT: [[UMUL_NOT_OV:%.*]] = xor i1 [[UMUL_OV]], true +; CHECK-NEXT: ret i1 [[UMUL_NOT_OV]] ; %y = call i8 @gen8() %t0 = mul i8 %y, %x ; swapped @@ -85,11 +85,12 @@ define i1 @t5_extrause0(i8 %x, i8 %y) { ; CHECK-LABEL: @t5_extrause0( -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: call void @use8(i8 [[T0]]) -; CHECK-NEXT: [[T1:%.*]] = udiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[UMUL:%.*]] = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]]) +; CHECK-NEXT: [[UMUL_VAL:%.*]] = extractvalue { i8, i1 } [[UMUL]], 0 +; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i8, i1 } [[UMUL]], 1 +; CHECK-NEXT: [[UMUL_NOT_OV:%.*]] = xor i1 [[UMUL_OV]], true +; CHECK-NEXT: call void @use8(i8 [[UMUL_VAL]]) +; CHECK-NEXT: ret i1 [[UMUL_NOT_OV]] ; %t0 = mul i8 %x, %y call void @use8(i8 %t0) diff --git a/llvm/test/Transforms/InstCombine/unsigned-mul-overflow-check-via-mul-udiv.ll b/llvm/test/Transforms/InstCombine/unsigned-mul-overflow-check-via-mul-udiv.ll --- a/llvm/test/Transforms/InstCombine/unsigned-mul-overflow-check-via-mul-udiv.ll +++ b/llvm/test/Transforms/InstCombine/unsigned-mul-overflow-check-via-mul-udiv.ll @@ -8,10 +8,9 @@ define i1 @t0_basic(i8 %x, i8 %y) { ; CHECK-LABEL: @t0_basic( -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[T1:%.*]] = udiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[UMUL:%.*]] = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]]) +; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i8, i1 } [[UMUL]], 1 +; CHECK-NEXT: ret i1 [[UMUL_OV]] ; %t0 = mul i8 %x, %y %t1 = udiv i8 %t0, %x @@ -21,10 +20,9 @@ define <2 x i1> @t1_vec(<2 x i8> %x, <2 x i8> %y) { ; CHECK-LABEL: @t1_vec( -; CHECK-NEXT: [[T0:%.*]] = mul <2 x i8> [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: [[T1:%.*]] = udiv <2 x i8> [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne <2 x i8> [[T1]], [[Y]] -; CHECK-NEXT: ret <2 x i1> [[R]] +; CHECK-NEXT: [[UMUL:%.*]] = call { <2 x i8>, <2 x i1> } @llvm.umul.with.overflow.v2i8(<2 x i8> [[X:%.*]], <2 x i8> [[Y:%.*]]) +; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { <2 x i8>, <2 x i1> } [[UMUL]], 1 +; CHECK-NEXT: ret <2 x i1> [[UMUL_OV]] ; %t0 = mul <2 x i8> %x, %y %t1 = udiv <2 x i8> %t0, %x @@ -37,10 +35,9 @@ define i1 @t2_commutative(i8 %x) { ; CHECK-LABEL: @t2_commutative( ; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8() -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]] -; CHECK-NEXT: [[T1:%.*]] = udiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[UMUL:%.*]] = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]]) +; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i8, i1 } [[UMUL]], 1 +; CHECK-NEXT: ret i1 [[UMUL_OV]] ; %y = call i8 @gen8() %t0 = mul i8 %y, %x ; swapped @@ -52,10 +49,9 @@ define i1 @t3_commutative(i8 %x) { ; CHECK-LABEL: @t3_commutative( ; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8() -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]] -; CHECK-NEXT: [[T1:%.*]] = udiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[UMUL:%.*]] = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]]) +; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i8, i1 } [[UMUL]], 1 +; CHECK-NEXT: ret i1 [[UMUL_OV]] ; %y = call i8 @gen8() %t0 = mul i8 %y, %x ; swapped @@ -67,10 +63,9 @@ define i1 @t4_commutative(i8 %x) { ; CHECK-LABEL: @t4_commutative( ; CHECK-NEXT: [[Y:%.*]] = call i8 @gen8() -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[Y]], [[X:%.*]] -; CHECK-NEXT: [[T1:%.*]] = udiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[Y]], [[T1]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[UMUL:%.*]] = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y]]) +; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i8, i1 } [[UMUL]], 1 +; CHECK-NEXT: ret i1 [[UMUL_OV]] ; %y = call i8 @gen8() %t0 = mul i8 %y, %x ; swapped @@ -85,11 +80,11 @@ define i1 @t5_extrause0(i8 %x, i8 %y) { ; CHECK-LABEL: @t5_extrause0( -; CHECK-NEXT: [[T0:%.*]] = mul i8 [[X:%.*]], [[Y:%.*]] -; CHECK-NEXT: call void @use8(i8 [[T0]]) -; CHECK-NEXT: [[T1:%.*]] = udiv i8 [[T0]], [[X]] -; CHECK-NEXT: [[R:%.*]] = icmp ne i8 [[T1]], [[Y]] -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: [[UMUL:%.*]] = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 [[X:%.*]], i8 [[Y:%.*]]) +; CHECK-NEXT: [[UMUL_VAL:%.*]] = extractvalue { i8, i1 } [[UMUL]], 0 +; CHECK-NEXT: [[UMUL_OV:%.*]] = extractvalue { i8, i1 } [[UMUL]], 1 +; CHECK-NEXT: call void @use8(i8 [[UMUL_VAL]]) +; CHECK-NEXT: ret i1 [[UMUL_OV]] ; %t0 = mul i8 %x, %y call void @use8(i8 %t0)