diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -1618,6 +1618,9 @@ if (Instruction *Res = foldBinOpOfDisplacedShifts(I)) return Res; + if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I)) + return Res; + return Changed ? &I : nullptr; } @@ -2466,6 +2469,9 @@ } } + if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I)) + return Res; + return TryToNarrowDeduceFlags(); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -454,6 +454,12 @@ // -> (BinOp (logic_shift (BinOp X, Y)), Mask) Instruction *foldBinOpShiftWithShift(BinaryOperator &I); + /// Tries to simplify binops of a select and a zext/sext of its condition + /// (or of its negation, in which case the folded constants swap arms): + /// (Binop (ext C), (select C, T, F)) + /// -> (select C, (Binop (ext true), T), (Binop (ext false), F)) + Instruction *foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I); + /// This tries to simplify binary operations by factorizing out common terms /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)"). Value *tryFactorizationFolds(BinaryOperator &I); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -474,6 +474,9 @@ if (Instruction *Ext = narrowMathIfNoOverflow(I)) return Ext; + if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I)) + return Res; + // min(X, Y) * max(X, Y) => X * Y.
if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)), m_c_SMin(m_Deferred(X), m_Deferred(Y))),
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -870,6 +870,84 @@
   return MatchBinOp(1);
 }
 
+// (Binop (zext C), (select C, T, F))
+//    -> (select C, (binop 1, T), (binop 0, F))
+//
+// (Binop (sext C), (select C, T, F))
+//    -> (select C, (binop -1, T), (binop 0, F))
+//
+// Attempt to simplify binary operations into a select with folded args, when
+// one operand of the binop is a select instruction and the other operand is a
+// zext/sext of an i1 (or vector of i1) whose value is the select condition or
+// its negation; for the negation the two arms swap their folded constants.
+Instruction *
+InstCombinerImpl::foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I) {
+  // TODO: this simplification may be extended to any speculatable instruction,
+  // not just binops, and would possibly be handled better in FoldOpIntoSelect.
+  Instruction::BinaryOps Opc = I.getOpcode();
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  Value *A, *CondVal, *TrueVal, *FalseVal;
+  Value *CastOp = nullptr;
+
+  auto MatchSelectAndCast = [&](Value *ExtOp, Value *SelectOp) {
+    return match(ExtOp, m_ZExtOrSExt(m_Value(A))) &&
+           A->getType()->getScalarSizeInBits() == 1 &&
+           match(SelectOp, m_Select(m_Value(CondVal), m_Value(TrueVal),
+                                    m_Value(FalseVal)));
+  };
+
+  // Make sure one side of the binop is a select instruction, and the other is a
+  // zero/sign extension operating on an i1 (or a vector of i1).
+  if (MatchSelectAndCast(LHS, RHS))
+    CastOp = LHS;
+  else if (MatchSelectAndCast(RHS, LHS))
+    CastOp = RHS;
+  else
+    return nullptr;
+
+  // Build (binop V, Ext), or (binop Ext, V) when the cast is the LHS, where
+  // Ext is the value the cast takes in that select arm: 0 when the extended
+  // i1 is false, and 1 (zext) or -1 (sext) when it is true.
+  auto NewFoldedConst = [&](bool IsExtZero, Value *V) {
+    bool IsCastOpRHS = (CastOp == RHS);
+    bool IsZExt = isa<ZExtInst>(CastOp);
+    Type *Ty = V->getType();
+    Constant *C;
+
+    if (IsExtZero) {
+      C = Constant::getNullValue(Ty);
+    } else if (IsZExt) {
+      // Use the scalar width: Ty may be a vector type, for which
+      // getIntegerBitWidth() asserts; getIntegerValue() splats the constant.
+      C = Constant::getIntegerValue(Ty, APInt(Ty->getScalarSizeInBits(), 1));
+    } else {
+      C = Constant::getAllOnesValue(Ty);
+    }
+
+    return IsCastOpRHS ? Builder.CreateBinOp(Opc, V, C)
+                       : Builder.CreateBinOp(Opc, C, V);
+  };
+
+  // If the value used in the zext/sext is the select condition, or its
+  // negation, the binop can be simplified. The arm values are built into
+  // locals first: C++ leaves the evaluation order of function arguments
+  // unspecified, so building them inside the Create() call would make the
+  // order of the emitted instructions depend on the host compiler.
+  if (CondVal == A) {
+    Value *NewTrueVal = NewFoldedConst(/*IsExtZero=*/false, TrueVal);
+    Value *NewFalseVal = NewFoldedConst(/*IsExtZero=*/true, FalseVal);
+    return SelectInst::Create(CondVal, NewTrueVal, NewFalseVal);
+  }
+
+  if (match(A, m_Not(m_Specific(CondVal)))) {
+    Value *NewTrueVal = NewFoldedConst(/*IsExtZero=*/true, TrueVal);
+    Value *NewFalseVal = NewFoldedConst(/*IsExtZero=*/false, FalseVal);
+    return SelectInst::Create(CondVal, NewTrueVal, NewFalseVal);
+  }
+
+  return nullptr;
+}
+
 Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) {
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
diff --git a/llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll b/llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll --- a/llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll +++ b/llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll @@ -4,9 +4,7 @@ define i64 @add_select_zext(i1 %c) { ; CHECK-LABEL: define i64 @add_select_zext ; CHECK-SAME: (i1 [[C:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C]], i64 64, i64 1 -; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[C]] to i64 -; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i64 [[SEL]], [[EXT]] +; CHECK-NEXT: [[ADD:%.*]] = select i1 [[C]], i64 65, i64 1 ; CHECK-NEXT: ret i64 [[ADD]] ; %sel = select i1 %c, i64 64, i64 1 @@ -18,9 +16,7 @@ define i64 @add_select_sext(i1 %c) {
; CHECK-LABEL: define i64 @add_select_sext ; CHECK-SAME: (i1 [[C:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C]], i64 64, i64 1 -; CHECK-NEXT: [[EXT:%.*]] = sext i1 [[C]] to i64 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i64 [[SEL]], [[EXT]] +; CHECK-NEXT: [[ADD:%.*]] = select i1 [[C]], i64 63, i64 1 ; CHECK-NEXT: ret i64 [[ADD]] ; %sel = select i1 %c, i64 64, i64 1 @@ -32,10 +28,7 @@ define i64 @add_select_not_zext(i1 %c) { ; CHECK-LABEL: define i64 @add_select_not_zext ; CHECK-SAME: (i1 [[C:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C]], i64 64, i64 1 -; CHECK-NEXT: [[NOT_C:%.*]] = xor i1 [[C]], true -; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[NOT_C]] to i64 -; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i64 [[SEL]], [[EXT]] +; CHECK-NEXT: [[ADD:%.*]] = select i1 [[C]], i64 64, i64 2 ; CHECK-NEXT: ret i64 [[ADD]] ; %sel = select i1 %c, i64 64, i64 1 @@ -48,10 +41,7 @@ define i64 @add_select_not_sext(i1 %c) { ; CHECK-LABEL: define i64 @add_select_not_sext ; CHECK-SAME: (i1 [[C:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C]], i64 64, i64 1 -; CHECK-NEXT: [[NOT_C:%.*]] = xor i1 [[C]], true -; CHECK-NEXT: [[EXT:%.*]] = sext i1 [[NOT_C]] to i64 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i64 [[SEL]], [[EXT]] +; CHECK-NEXT: [[ADD:%.*]] = select i1 [[C]], i64 64, i64 0 ; CHECK-NEXT: ret i64 [[ADD]] ; %sel = select i1 %c, i64 64, i64 1 @@ -64,9 +54,7 @@ define i64 @sub_select_sext(i1 %c, i64 %arg) { ; CHECK-LABEL: define i64 @sub_select_sext ; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C]], i64 64, i64 [[ARG]] -; CHECK-NEXT: [[EXT_NEG:%.*]] = zext i1 [[C]] to i64 -; CHECK-NEXT: [[SUB:%.*]] = add i64 [[SEL]], [[EXT_NEG]] +; CHECK-NEXT: [[SUB:%.*]] = select i1 [[C]], i64 65, i64 [[ARG]] ; CHECK-NEXT: ret i64 [[SUB]] ; %sel = select i1 %c, i64 64, i64 %arg @@ -78,10 +66,7 @@ define i64 @sub_select_not_zext(i1 %c, i64 %arg) { ; CHECK-LABEL: define i64 @sub_select_not_zext ; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) { -; CHECK-NEXT:
[[SEL:%.*]] = select i1 [[C]], i64 [[ARG]], i64 64 -; CHECK-NEXT: [[NOT_C:%.*]] = xor i1 [[C]], true -; CHECK-NEXT: [[EXT_NEG:%.*]] = sext i1 [[NOT_C]] to i64 -; CHECK-NEXT: [[SUB:%.*]] = add i64 [[SEL]], [[EXT_NEG]] +; CHECK-NEXT: [[SUB:%.*]] = select i1 [[C]], i64 [[ARG]], i64 63 ; CHECK-NEXT: ret i64 [[SUB]] ; %sel = select i1 %c, i64 %arg, i64 64 @@ -94,10 +79,7 @@ define i64 @sub_select_not_sext(i1 %c, i64 %arg) { ; CHECK-LABEL: define i64 @sub_select_not_sext ; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C]], i64 [[ARG]], i64 64 -; CHECK-NEXT: [[NOT_C:%.*]] = xor i1 [[C]], true -; CHECK-NEXT: [[EXT_NEG:%.*]] = zext i1 [[NOT_C]] to i64 -; CHECK-NEXT: [[SUB:%.*]] = add i64 [[SEL]], [[EXT_NEG]] +; CHECK-NEXT: [[SUB:%.*]] = select i1 [[C]], i64 [[ARG]], i64 65 ; CHECK-NEXT: ret i64 [[SUB]] ; %sel = select i1 %c, i64 %arg, i64 64 @@ -122,9 +104,7 @@ define i64 @mul_select_sext(i1 %c) { ; CHECK-LABEL: define i64 @mul_select_sext ; CHECK-SAME: (i1 [[C:%.*]]) { -; CHECK-NEXT: [[EXT:%.*]] = sext i1 [[C]] to i64 -; CHECK-NEXT: [[TMP1:%.*]] = select i1 [[C]], i64 6, i64 0 -; CHECK-NEXT: [[MUL:%.*]] = shl i64 [[EXT]], [[TMP1]] +; CHECK-NEXT: [[MUL:%.*]] = select i1 [[C]], i64 -64, i64 0 ; CHECK-NEXT: ret i64 [[MUL]] ; %sel = select i1 %c, i64 64, i64 1 @@ -168,10 +148,7 @@ define i64 @multiuse_add(i1 %c) { ; CHECK-LABEL: define i64 @multiuse_add ; CHECK-SAME: (i1 [[C:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C]], i64 64, i64 1 -; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[C]] to i64 -; CHECK-NEXT: [[ADD:%.*]] = add nuw nsw i64 [[SEL]], [[EXT]] -; CHECK-NEXT: [[ADD2:%.*]] = add nuw nsw i64 [[ADD]], 1 +; CHECK-NEXT: [[ADD2:%.*]] = select i1 [[C]], i64 66, i64 2 ; CHECK-NEXT: ret i64 [[ADD2]] ; %sel = select i1 %c, i64 64, i64 1 @@ -184,10 +161,7 @@ define i64 @multiuse_select(i1 %c) { ; CHECK-LABEL: define i64 @multiuse_select ; CHECK-SAME: (i1 [[C:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C]], i64 64, i64 0 -;
CHECK-NEXT: [[EXT_NEG:%.*]] = sext i1 [[C]] to i64 -; CHECK-NEXT: [[ADD:%.*]] = add nsw i64 [[SEL]], [[EXT_NEG]] -; CHECK-NEXT: [[MUL:%.*]] = mul nsw i64 [[SEL]], [[ADD]] +; CHECK-NEXT: [[MUL:%.*]] = select i1 [[C]], i64 4032, i64 0 ; CHECK-NEXT: ret i64 [[MUL]] ; %sel = select i1 %c, i64 64, i64 0 @@ -200,9 +174,8 @@ define i64 @select_non_const_sides(i1 %c, i64 %arg1, i64 %arg2) { ; CHECK-LABEL: define i64 @select_non_const_sides ; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG1:%.*]], i64 [[ARG2:%.*]]) { -; CHECK-NEXT: [[EXT_NEG:%.*]] = sext i1 [[C]] to i64 -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C]], i64 [[ARG1]], i64 [[ARG2]] -; CHECK-NEXT: [[SUB:%.*]] = add i64 [[SEL]], [[EXT_NEG]] +; CHECK-NEXT: [[TMP1:%.*]] = add i64 [[ARG1]], -1 +; CHECK-NEXT: [[SUB:%.*]] = select i1 [[C]], i64 [[TMP1]], i64 [[ARG2]] ; CHECK-NEXT: ret i64 [[SUB]] ; %ext = zext i1 %c to i64 @@ -214,9 +187,9 @@ define i6 @sub_select_sext_op_swapped_non_const_args(i1 %c, i6 %argT, i6 %argF) { ; CHECK-LABEL: define i6 @sub_select_sext_op_swapped_non_const_args ; CHECK-SAME: (i1 [[C:%.*]], i6 [[ARGT:%.*]], i6 [[ARGF:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C]], i6 [[ARGT]], i6 [[ARGF]] -; CHECK-NEXT: [[EXT:%.*]] = sext i1 [[C]] to i6 -; CHECK-NEXT: [[SUB:%.*]] = sub i6 [[EXT]], [[SEL]] +; CHECK-DAG: [[TMP1:%.*]] = xor i6 [[ARGT]], -1 +; CHECK-DAG: [[TMP2:%.*]] = sub i6 0, [[ARGF]] +; CHECK-NEXT: [[SUB:%.*]] = select i1 [[C]], i6 [[TMP1]], i6 [[TMP2]] ; CHECK-NEXT: ret i6 [[SUB]] ; %sel = select i1 %c, i6 %argT, i6 %argF @@ -228,9 +201,9 @@ define i6 @sub_select_zext_op_swapped_non_const_args(i1 %c, i6 %argT, i6 %argF) { ; CHECK-LABEL: define i6 @sub_select_zext_op_swapped_non_const_args ; CHECK-SAME: (i1 [[C:%.*]], i6 [[ARGT:%.*]], i6 [[ARGF:%.*]]) { -; CHECK-NEXT: [[SEL:%.*]] = select i1 [[C]], i6 [[ARGT]], i6 [[ARGF]] -; CHECK-NEXT: [[EXT:%.*]] = zext i1 [[C]] to i6 -; CHECK-NEXT: [[SUB:%.*]] = sub i6 [[EXT]], [[SEL]] +; CHECK-DAG: [[TMP1:%.*]] = sub i6 1, [[ARGT]] +; CHECK-DAG: [[TMP2:%.*]]
= sub i6 0, [[ARGF]] +; CHECK-NEXT: [[SUB:%.*]] = select i1 [[C]], i6 [[TMP1]], i6 [[TMP2]] ; CHECK-NEXT: ret i6 [[SUB]] ; %sel = select i1 %c, i6 %argT, i6 %argF