diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -715,6 +715,45 @@
   return nullptr;
 }
 
+// select (x == 0), 0, x * y --> freeze(y) * x
+// select (y == 0), 0, x * y --> freeze(x) * y
+// Using the mul instead of 0 can make the result more poisonous, so the mul
+// operand not compared with zero is frozen. The mul is updated in place so
+// that its other users, even those before the select, remain valid. Vector
+// zeros (compared or selected) may contain undef lanes.
+static Instruction *foldSelectZeroOrMul(SelectInst &SI,
+                                        InstCombiner::BuilderTy &Builder) {
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+  Value *CmpLHS, *Y;
+  Constant *CmpRHS;
+  ICmpInst::Predicate Predicate;
+
+  // Only equality compares against zero (undef vector lanes allowed).
+  if (!match(SI.getCondition(),
+             m_ICmp(Predicate, m_Value(CmpLHS), m_Constant(CmpRHS))) ||
+      !ICmpInst::isEquality(Predicate) || !match(CmpRHS, m_Zero()))
+    return nullptr;
+  if (Predicate == ICmpInst::ICMP_NE)
+    std::swap(TrueVal, FalseVal);
+
+  // The mul must be an instruction that uses the compared value; otherwise it
+  // is not known to be zero whenever the compare holds.
+  auto *TrueValC = dyn_cast<Constant>(TrueVal);
+  auto *MulI = dyn_cast<BinaryOperator>(FalseVal);
+  if (!TrueValC || !MulI ||
+      !match(MulI, m_c_Mul(m_Specific(CmpLHS), m_Value(Y))) ||
+      !match(Constant::mergeUndefsWith(TrueValC, CmpRHS), m_Zero()))
+    return nullptr;
+
+  IRBuilderBase::InsertPointGuard Guard(Builder);
+  Builder.SetInsertPoint(MulI);
+  Value *FrY = Builder.CreateFreeze(Y, Y->getName() + ".fr");
+  MulI->setOperand(MulI->getOperand(0) == Y ? 0 : 1, FrY);
+  SI.replaceAllUsesWith(MulI);
+  return &SI; // Signals the change; SI is now unused.
+}
+
 /// Transform patterns such as (a > b) ? a - b : 0 into usub.sat(a, b).
 /// There are 8 commuted/swapped variants of this pattern.
 /// TODO: Also support a - UMIN(a,b) patterns.
@@ -2917,6 +2956,8 @@
     return Add;
   if (Instruction *Or = foldSetClearBits(SI, Builder))
     return Or;
+  if (Instruction *Mul = foldSelectZeroOrMul(SI, Builder))
+    return Mul;
 
   // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z))
   auto *TI = dyn_cast<Instruction>(TrueVal);
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -2847,3 +2847,124 @@
 declare void @use(i1)
 declare void @use_i8(i8)
 declare i32 @llvm.cttz.i32(i32, i1 immarg)
+
+define i32 @mul_select_eq_zero(i32 %x, i32 %y) {
+; CHECK-LABEL: @mul_select_eq_zero(
+; CHECK-NEXT:    [[TMP1:%.*]] = freeze i32 [[Y:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = mul i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp eq i32 %x, 0
+  %m = mul i32 %x, %y
+  %r = select i1 %c, i32 0, i32 %m
+  ret i32 %r
+}
+
+define i32 @mul_select_eq_zero_commute(i32 %x, i32 %y) {
+; CHECK-LABEL: @mul_select_eq_zero_commute(
+; CHECK-NEXT:    [[TMP1:%.*]] = freeze i32 [[X:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = mul i32 [[TMP1]], [[Y:%.*]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp eq i32 %y, 0
+  %m = mul i32 %x, %y
+  %r = select i1 %c, i32 0, i32 %m
+  ret i32 %r
+}
+
+define i32 @mul_select_eq_zero_copy_flags(i32 %x, i32 %y) {
+; CHECK-LABEL: @mul_select_eq_zero_copy_flags(
+; CHECK-NEXT:    [[TMP1:%.*]] = freeze i32 [[Y:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = mul nuw nsw i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp eq i32 %x, 0
+  %m = mul nuw nsw i32 %x, %y
+  %r = select i1 %c, i32 0, i32 %m
+  ret i32 %r
+}
+
+define i32 @mul_select_ne_zero(i32 %x, i32 %y) {
+; CHECK-LABEL: @mul_select_ne_zero(
+; CHECK-NEXT:    [[TMP1:%.*]] = freeze i32 [[Y:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = mul i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp ne i32 %x, 0
+  %m = mul i32 %x, %y
+  %r = select i1 %c, i32 %m, i32 0
+  ret i32 %r
+}
+
+define i32 @mul_select_eq_undef(i32 %x, i32 %y) {
+; CHECK-LABEL: @mul_select_eq_undef(
+; CHECK-NEXT:    ret i32 0
+;
+  %c = icmp eq i32 %x, undef
+  %m = mul i32 %x, %y
+  %r = select i1 %c, i32 0, i32 %m
+  ret i32 %r
+}
+
+define i32 @mul_select_eq_zero_sel_undef(i32 %x, i32 %y) {
+; CHECK-LABEL: @mul_select_eq_zero_sel_undef(
+; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 0
+; CHECK-NEXT:    [[M:%.*]] = mul i32 [[X]], [[Y:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = select i1 [[C]], i32 undef, i32 [[M]]
+; CHECK-NEXT:    ret i32 [[R]]
+;
+  %c = icmp eq i32 %x, 0
+  %m = mul i32 %x, %y
+  %r = select i1 %c, i32 undef, i32 %m
+  ret i32 %r
+}
+
+define i32 @mul_select_eq_zero_multiple_users(i32 %x, i32 %y) {
+; CHECK-LABEL: @mul_select_eq_zero_multiple_users(
+; CHECK-NEXT:    [[TMP1:%.*]] = freeze i32 [[Y:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = mul i32 [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    [[P:%.*]] = mul i32 [[R]], [[R]]
+; CHECK-NEXT:    ret i32 [[P]]
+;
+  %c = icmp eq i32 %x, 0
+  %m = mul i32 %x, %y
+  %r = select i1 %c, i32 0, i32 %m
+  %p = mul i32 %m, %r
+  ret i32 %p
+}
+
+define <4 x i32> @mul_select_eq_zero_vector(<4 x i32> %x, <4 x i32> %y) {
+; CHECK-LABEL: @mul_select_eq_zero_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = freeze <4 x i32> [[Y:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = mul <4 x i32> [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret <4 x i32> [[R]]
+;
+  %c = icmp eq <4 x i32> %x, zeroinitializer
+  %m = mul <4 x i32> %x, %y
+  %r = select <4 x i1> %c, <4 x i32> zeroinitializer, <4 x i32> %m
+  ret <4 x i32> %r
+}
+
+define <2 x i32> @mul_select_eq_undef_vector(<2 x i32> %x, <2 x i32> %y) {
+; CHECK-LABEL: @mul_select_eq_undef_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = freeze <2 x i32> [[Y:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = mul <2 x i32> [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret <2 x i32> [[R]]
+;
+  %c = icmp eq <2 x i32> %x, <i32 0, i32 undef>
+  %m = mul <2 x i32> %x, %y
+  %r = select <2 x i1> %c, <2 x i32> <i32 0, i32 42>, <2 x i32> %m
+  ret <2 x i32> %r
+}
+
+define <2 x i32> @mul_select_eq_zero_sel_undef_vector(<2 x i32> %x, <2 x i32> %y) {
+; CHECK-LABEL: @mul_select_eq_zero_sel_undef_vector(
+; CHECK-NEXT:    [[TMP1:%.*]] = freeze <2 x i32> [[Y:%.*]]
+; CHECK-NEXT:    [[R:%.*]] = mul <2 x i32> [[TMP1]], [[X:%.*]]
+; CHECK-NEXT:    ret <2 x i32> [[R]]
+;
+  %c = icmp eq <2 x i32> %x, zeroinitializer
+  %m = mul <2 x i32> %x, %y
+  %r = select <2 x i1> %c, <2 x i32> <i32 0, i32 undef>, <2 x i32> %m
+  ret <2 x i32> %r
+}