diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -723,6 +723,66 @@
   return nullptr;
 }
 
+/// Fold a select of zero (or undef) and a mul into the mul itself when the
+/// select condition compares one of the mul operands with zero:
+///   select (x == 0), 0, x * y     --> freeze(y) * x
+///   select (y == 0), 0, x * y     --> freeze(x) * y
+///   select (x == 0), undef, x * y --> freeze(y) * x
+///   select (x == undef), 0, x * y --> freeze(y) * x
+/// Usage of mul instead of 0 will make the result more poisonous,
+/// so the operand that was not checked in the condition should be frozen.
+/// The latter folding is applied only when a constant compared with x
+/// is a vector consisting of 0 and undefs. If a constant compared with x
+/// is a scalar undefined value or undefined vector then an expression
+/// should be already folded into a constant.
+///
+/// Returns the value of IC.replaceInstUsesWith(SI, mul) when the fold is
+/// applied and nullptr when the pattern does not match.
+static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) {
+  auto *CondVal = SI.getCondition();
+  auto *TrueVal = SI.getTrueValue();
+  auto *FalseVal = SI.getFalseValue();
+  Value *X, *Y;
+  ICmpInst::Predicate Predicate;
+
+  // Assuming that constant compared with zero is not undef (but it may be
+  // a vector with some undef elements). Otherwise (when a constant is undef)
+  // the select expression should be already simplified.
+  if (!match(CondVal, m_ICmp(Predicate, m_Value(X), m_Zero())) ||
+      !ICmpInst::isEquality(Predicate))
+    return nullptr;
+
+  // For "x != 0" swap the arms so that TrueVal is always the value that is
+  // selected when x is zero.
+  if (Predicate == ICmpInst::ICMP_NE)
+    std::swap(TrueVal, FalseVal);
+
+  // Check that TrueVal is a constant instead of matching it with m_Zero()
+  // to handle the case when it is a scalar undef value.
+  auto *TrueValC = dyn_cast<Constant>(TrueVal);
+  if (!TrueValC)
+    return nullptr;
+
+  // m_c_Mul may also match a mul constant expression; only a real
+  // instruction can have its operand replaced below.
+  auto *FalseValI = dyn_cast<Instruction>(FalseVal);
+  if (!FalseValI || !match(FalseValI, m_c_Mul(m_Specific(X), m_Value(Y))))
+    return nullptr;
+
+  auto *ZeroC = cast<Constant>(cast<ICmpInst>(CondVal)->getOperand(1));
+  auto *MergedC = Constant::mergeUndefsWith(TrueValC, ZeroC);
+  // If X is compared with 0 then TrueVal could be either zero or undef.
+  // m_Zero matches vectors containing some undef elements, but for scalars
+  // m_Undef should be used explicitly.
+  if (!match(MergedC, m_Zero()) && !match(MergedC, m_Undef()))
+    return nullptr;
+
+  // Freeze Y in front of the mul and make the mul use the frozen value.
+  auto *FrY = new FreezeInst(Y, Y->getName() + ".fr", FalseValI);
+  IC.replaceOperand(*FalseValI, FalseValI->getOperand(0) == Y ? 0 : 1, FrY);
+  return IC.replaceInstUsesWith(SI, FalseValI);
+}
+
 /// Transform patterns such as (a > b) ? a - b : 0 into usub.sat(a, b).
 /// There are 8 commuted/swapped variants of this pattern.
 /// TODO: Also support a - UMIN(a,b) patterns.
@@ -2930,6 +2990,8 @@
       return Add;
     if (Instruction *Or = foldSetClearBits(SI, Builder))
       return Or;
+    if (Instruction *Mul = foldSelectZeroOrMul(SI, *this))
+      return Mul;
 
     // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z))
     auto *TI = dyn_cast(TrueVal);
diff --git a/llvm/test/Transforms/InstCombine/select.ll b/llvm/test/Transforms/InstCombine/select.ll
--- a/llvm/test/Transforms/InstCombine/select.ll
+++ b/llvm/test/Transforms/InstCombine/select.ll
@@ -2844,12 +2844,12 @@
   ret <2 x i1> %r
 }
 
+; select (x == 0), 0, x * y --> freeze(y) * x
 define i32 @mul_select_eq_zero(i32 %x, i32 %y) {
 ; CHECK-LABEL: @mul_select_eq_zero(
-; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 0
-; CHECK-NEXT: [[M:%.*]] = mul i32 [[X]], [[Y:%.*]]
-; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 0, i32 [[M]]
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: [[Y_FR:%.*]] = freeze i32 [[Y:%.*]]
+; CHECK-NEXT: [[M:%.*]] = mul i32 [[Y_FR]], [[X:%.*]]
+; CHECK-NEXT: ret i32 [[M]]
 ;
   %c = icmp eq i32 %x, 0
   %m = mul i32 %x, %y
@@ -2857,12 +2857,12 @@
   ret i32 %r
 }
 
+; select (y == 0), 0, x * y --> freeze(x) * y
 define i32 @mul_select_eq_zero_commute(i32 %x, i32 %y) {
 ; CHECK-LABEL: @mul_select_eq_zero_commute(
-; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[Y:%.*]], 0
-; CHECK-NEXT: [[M:%.*]] = mul i32 [[X:%.*]], [[Y]]
-; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 0, i32 [[M]]
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: [[X_FR:%.*]] = freeze i32 [[X:%.*]]
+; CHECK-NEXT: [[M:%.*]] = mul i32 [[X_FR]], [[Y:%.*]]
+; CHECK-NEXT: ret i32 [[M]]
 ;
   %c = icmp eq i32 %y, 0
   %m = mul i32 %x, %y
@@ -2870,12 +2870,12 @@
   ret i32 %r
 }
 
+; Check that mul's flags are preserved during the transformation.
 define i32 @mul_select_eq_zero_copy_flags(i32 %x, i32 %y) {
 ; CHECK-LABEL: @mul_select_eq_zero_copy_flags(
-; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 0
-; CHECK-NEXT: [[M:%.*]] = mul nuw nsw i32 [[X]], [[Y:%.*]]
-; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 0, i32 [[M]]
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: [[Y_FR:%.*]] = freeze i32 [[Y:%.*]]
+; CHECK-NEXT: [[M:%.*]] = mul nuw nsw i32 [[Y_FR]], [[X:%.*]]
+; CHECK-NEXT: ret i32 [[M]]
 ;
   %c = icmp eq i32 %x, 0
   %m = mul nuw nsw i32 %x, %y
@@ -2883,12 +2883,13 @@
   ret i32 %r
 }
 
+; Check that the transformation is also applied when the condition is inverted.
+; select (x != 0), x * y, 0 --> freeze(y) * x
 define i32 @mul_select_ne_zero(i32 %x, i32 %y) {
 ; CHECK-LABEL: @mul_select_ne_zero(
-; CHECK-NEXT: [[C_NOT:%.*]] = icmp eq i32 [[X:%.*]], 0
-; CHECK-NEXT: [[M:%.*]] = mul i32 [[X]], [[Y:%.*]]
-; CHECK-NEXT: [[R:%.*]] = select i1 [[C_NOT]], i32 0, i32 [[M]]
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: [[Y_FR:%.*]] = freeze i32 [[Y:%.*]]
+; CHECK-NEXT: [[M:%.*]] = mul i32 [[Y_FR]], [[X:%.*]]
+; CHECK-NEXT: ret i32 [[M]]
 ;
   %c = icmp ne i32 %x, 0
   %m = mul i32 %x, %y
@@ -2896,12 +2897,14 @@
   ret i32 %r
 }
 
+; Check that if one of the select's arms is undef then the expression is
+; folded into a mul as if the arm were 0 instead of undef.
+; select (x == 0), undef, x * y --> freeze(y) * x
 define i32 @mul_select_eq_zero_sel_undef(i32 %x, i32 %y) {
 ; CHECK-LABEL: @mul_select_eq_zero_sel_undef(
-; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X:%.*]], 0
-; CHECK-NEXT: [[M:%.*]] = mul i32 [[X]], [[Y:%.*]]
-; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 undef, i32 [[M]]
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: [[Y_FR:%.*]] = freeze i32 [[Y:%.*]]
+; CHECK-NEXT: [[M:%.*]] = mul i32 [[Y_FR]], [[X:%.*]]
+; CHECK-NEXT: ret i32 [[M]]
 ;
   %c = icmp eq i32 %x, 0
   %m = mul i32 %x, %y
@@ -2909,15 +2912,16 @@
   ret i32 %r
 }
 
+; Check that the transformation is applied regardless of the number
+; of users of the expression.
 define i32 @mul_select_eq_zero_multiple_users(i32 %x, i32 %y) {
 ; CHECK-LABEL: @mul_select_eq_zero_multiple_users(
-; CHECK-NEXT: [[M:%.*]] = mul i32 [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT: [[Y_FR:%.*]] = freeze i32 [[Y:%.*]]
+; CHECK-NEXT: [[M:%.*]] = mul i32 [[Y_FR]], [[X:%.*]]
 ; CHECK-NEXT: call void @use_i32(i32 [[M]])
-; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[X]], 0
-; CHECK-NEXT: [[R:%.*]] = select i1 [[C]], i32 0, i32 [[M]]
 ; CHECK-NEXT: call void @use_i32(i32 [[M]])
-; CHECK-NEXT: call void @use_i32(i32 [[R]])
-; CHECK-NEXT: ret i32 [[R]]
+; CHECK-NEXT: call void @use_i32(i32 [[M]])
+; CHECK-NEXT: ret i32 [[M]]
 ;
   %m = mul i32 %x, %y
   call void @use_i32(i32 %m)
@@ -2928,6 +2932,8 @@
   ret i32 %r
 }
 
+; Negative test: select's condition is unrelated to the multiplied values,
+; so the transformation should not be applied.
 define i32 @mul_select_eq_zero_unrelated_condition(i32 %x, i32 %y, i32 %z) {
 ; CHECK-LABEL: @mul_select_eq_zero_unrelated_condition(
 ; CHECK-NEXT: [[C:%.*]] = icmp eq i32 [[Z:%.*]], 0
@@ -2941,12 +2947,12 @@
   ret i32 %r
 }
 
+; select (x == 0), 0, x * y --> freeze(y) * x
 define <4 x i32> @mul_select_eq_zero_vector(<4 x i32> %x, <4 x i32> %y) {
 ; CHECK-LABEL: @mul_select_eq_zero_vector(
-; CHECK-NEXT: [[C:%.*]] = icmp eq <4 x i32> [[X:%.*]], zeroinitializer
-; CHECK-NEXT: [[M:%.*]] = mul <4 x i32> [[X]], [[Y:%.*]]
-; CHECK-NEXT: [[R:%.*]] = select <4 x i1> [[C]], <4 x i32> zeroinitializer, <4 x i32> [[M]]
-; CHECK-NEXT: ret <4 x i32> [[R]]
+; CHECK-NEXT: [[Y_FR:%.*]] = freeze <4 x i32> [[Y:%.*]]
+; CHECK-NEXT: [[M:%.*]] = mul <4 x i32> [[Y_FR]], [[X:%.*]]
+; CHECK-NEXT: ret <4 x i32> [[M]]
 ;
   %c = icmp eq <4 x i32> %x, zeroinitializer
   %m = mul <4 x i32> %x, %y
@@ -2954,12 +2960,14 @@
   ret <4 x i32> %r
 }
 
+; Check that a select is folded into a mul if the condition's operand
+; is a vector consisting of zeros and undefs.
+; select (x == {0, undef, ...}), 0, x * y --> freeze(y) * x
 define <2 x i32> @mul_select_eq_undef_vector(<2 x i32> %x, <2 x i32> %y) {
 ; CHECK-LABEL: @mul_select_eq_undef_vector(
-; CHECK-NEXT: [[C:%.*]] = icmp eq <2 x i32> [[X:%.*]],
-; CHECK-NEXT: [[M:%.*]] = mul <2 x i32> [[X]], [[Y:%.*]]
-; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[C]], <2 x i32> , <2 x i32> [[M]]
-; CHECK-NEXT: ret <2 x i32> [[R]]
+; CHECK-NEXT: [[Y_FR:%.*]] = freeze <2 x i32> [[Y:%.*]]
+; CHECK-NEXT: [[M:%.*]] = mul <2 x i32> [[Y_FR]], [[X:%.*]]
+; CHECK-NEXT: ret <2 x i32> [[M]]
 ;
   %c = icmp eq <2 x i32> %x,
   %m = mul <2 x i32> %x, %y
@@ -2967,12 +2975,14 @@
   ret <2 x i32> %r
 }
 
+; Check that a select is folded into a mul if the select's other operand
+; is a vector consisting of zeros and undefs.
+; select (x == 0), {0, undef, ...}, x * y --> freeze(y) * x
 define <2 x i32> @mul_select_eq_zero_sel_undef_vector(<2 x i32> %x, <2 x i32> %y) {
 ; CHECK-LABEL: @mul_select_eq_zero_sel_undef_vector(
-; CHECK-NEXT: [[C:%.*]] = icmp eq <2 x i32> [[X:%.*]], zeroinitializer
-; CHECK-NEXT: [[M:%.*]] = mul <2 x i32> [[X]], [[Y:%.*]]
-; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[C]], <2 x i32> , <2 x i32> [[M]]
-; CHECK-NEXT: ret <2 x i32> [[R]]
+; CHECK-NEXT: [[Y_FR:%.*]] = freeze <2 x i32> [[Y:%.*]]
+; CHECK-NEXT: [[M:%.*]] = mul <2 x i32> [[Y_FR]], [[X:%.*]]
+; CHECK-NEXT: ret <2 x i32> [[M]]
 ;
   %c = icmp eq <2 x i32> %x, zeroinitializer
   %m = mul <2 x i32> %x, %y
@@ -2980,6 +2990,8 @@
   ret <2 x i32> %r
 }
 
+; Negative test: the select should not be folded into a mul because the
+; condition's operand and the select's operand do not merge into a zero vector.
 define <2 x i32> @mul_select_eq_undef_vector_not_merging_to_zero(<2 x i32> %x, <2 x i32> %y) {
 ; CHECK-LABEL: @mul_select_eq_undef_vector_not_merging_to_zero(
 ; CHECK-NEXT: [[C:%.*]] = icmp eq <2 x i32> [[X:%.*]],