Index: llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1613,6 +1613,9 @@
   if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
     return Res;
 
+  if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+    return Res;
+
   return Changed ? &I : nullptr;
 }
 
@@ -2461,6 +2464,9 @@
     }
   }
 
+  if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+    return Res;
+
   return TryToNarrowDeduceFlags();
 }
 
Index: llvm/lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -456,6 +456,13 @@
   // -> (BinOp (logic_shift (BinOp X, Y)), Mask)
   Instruction *foldBinOpShiftWithShift(BinaryOperator &I);
 
+  /// Tries to simplify a binop whose operands are a select and a zext/sext of
+  /// the select condition (or of its negation), in either operand order, e.g.
+  ///
+  /// (Binop (select C, T, F), (zext C))
+  ///    -> (select C, (Binop T, 1), (Binop F, 0))
+  Instruction *foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I);
+
   /// This tries to simplify binary operations by factorizing out common terms
   /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
   Value *tryFactorizationFolds(BinaryOperator &I);
Index: llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -474,6 +474,9 @@
   if (Instruction *Ext = narrowMathIfNoOverflow(I))
     return Ext;
 
+  if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+    return Res;
+
   // min(X, Y) * max(X, Y) => X * Y.
   if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)),
                                     m_c_SMin(m_Deferred(X), m_Deferred(Y))),
Index: llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -866,6 +866,78 @@
   return MatchBinOp(1);
 }
 
+// Fold a binop of a select and a zext/sext of the select condition (or of its
+// negation) into a select of two new binops, one per arm:
+//
+// (Binop (select C, T, F), (zext C))
+//   -> (select C, (Binop T, 1), (Binop F, 0))
+//
+// (Binop (select C, T, F), (sext C))
+//   -> (select C, (Binop T, -1), (Binop F, 0))
+//
+// Within each select arm the extension has a known constant value, so the new
+// binops constant-fold whenever that arm is itself a constant. The commuted
+// form (Binop (ext C), (select C, T, F)) is handled with the operand order
+// preserved, and an extension of (not C) swaps which arm receives the 1/-1
+// constant and which receives 0.
+Instruction *
+InstCombinerImpl::foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I) {
+  Instruction::BinaryOps Opc = I.getOpcode();
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  Value *A, *CondVal, *TrueVal, *FalseVal;
+
+  // Returns true if ExtOp is a zext/sext of an i1 (or <N x i1>) value A and
+  // SelOp is a select; binds A, CondVal, TrueVal and FalseVal on success.
+  auto MatchSelectAndCast = [&](Value *ExtOp, Value *SelOp) {
+    return match(ExtOp, m_ZExtOrSExt(m_Value(A))) &&
+           A->getType()->getScalarSizeInBits() == 1 &&
+           match(SelOp, m_Select(m_Value(CondVal), m_Value(TrueVal),
+                                 m_Value(FalseVal)));
+  };
+
+  // One operand must be the select and the other the i1 extension; the
+  // select-on-the-left order is tried first.
+  Value *CastOp;
+  if (MatchSelectAndCast(RHS, LHS))
+    CastOp = RHS;
+  else if (MatchSelectAndCast(LHS, RHS))
+    CastOp = LHS;
+  else
+    return nullptr;
+
+  // Values the extension takes when its i1 input is true / false. They are
+  // built from the binop's type, so they become splats for vector binops
+  // (the bit width must not be queried through IntegerType-only accessors).
+  Type *Ty = I.getType();
+  Constant *ExtOfTrue = isa<ZExtInst>(CastOp) ? ConstantInt::get(Ty, 1)
+                                              : Constant::getAllOnesValue(Ty);
+  Constant *ExtOfFalse = Constant::getNullValue(Ty);
+
+  // Rebuilds the binop for one select arm, with the extension replaced by the
+  // constant it takes in that arm and the original operand order kept.
+  auto FoldArm = [&](Value *Arm, Constant *ExtVal) {
+    return CastOp == RHS ? Builder.CreateBinOp(Opc, Arm, ExtVal)
+                         : Builder.CreateBinOp(Opc, ExtVal, Arm);
+  };
+
+  // Each arm is built in its own statement so that the order in which new
+  // instructions are created does not depend on argument evaluation order.
+  Value *NewTrueVal, *NewFalseVal;
+  if (A == CondVal) {
+    // ext(C): the extension is ExtOfTrue in the true arm, 0 in the false arm.
+    NewTrueVal = FoldArm(TrueVal, ExtOfTrue);
+    NewFalseVal = FoldArm(FalseVal, ExtOfFalse);
+  } else if (match(A, m_Not(m_Specific(CondVal)))) {
+    // ext(not C): the two constants trade arms.
+    NewTrueVal = FoldArm(TrueVal, ExtOfFalse);
+    NewFalseVal = FoldArm(FalseVal, ExtOfTrue);
+  } else {
+    return nullptr;
+  }
+
+  return SelectInst::Create(CondVal, NewTrueVal, NewFalseVal);
+}
+
 Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) {
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
Index: llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll
@@ -0,0 +1,171 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i64 @add_select_zext(i1 %c) {
+; CHECK-LABEL: define i64 @add_select_zext
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[C]], i64 65, i64 1
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %ext = zext i1 %c to i64
+  %add = add i64 %sel, %ext
+  ret i64 %add
+}
+
+define i64 @add_select_sext(i1 %c) {
+; CHECK-LABEL: define i64 @add_select_sext
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[C]], i64 63, i64 1
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %ext = sext i1 %c to i64
+  %add = add i64 %sel, %ext
+  ret i64 %add
+}
+
+define i64 @add_select_not_zext(i1 %c) {
+; CHECK-LABEL: define i64 @add_select_not_zext
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[C]], i64 64, i64 2
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %not.c = xor i1 %c, true
+  %ext = zext i1 %not.c to i64
+  %add = add i64 %sel, %ext
+  ret i64 %add
+}
+
+define i64 @add_select_not_sext(i1 %c) {
+; CHECK-LABEL: define i64 @add_select_not_sext
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[C]], i64 64, i64 0
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %not.c = xor i1 %c, true
+  %ext = sext i1 %not.c to i64
+  %add = add i64 %sel, %ext
+  ret i64 %add
+}
+
+define i64 @sub_select_sext(i1 %c, i64 %arg) {
+; CHECK-LABEL: define i64 @sub_select_sext
+; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i64 65, i64 [[ARG]]
+; CHECK-NEXT:    ret i64 [[SUB]]
+;
+  %sel = select i1 %c, i64 64, i64 %arg
+  %ext = sext i1 %c to i64
+  %sub = sub i64 %sel, %ext
+  ret i64 %sub
+}
+
+define i64 @sub_select_not_zext(i1 %c, i64 %arg) {
+; CHECK-LABEL: define i64 @sub_select_not_zext
+; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i64 [[ARG]], i64 63
+; CHECK-NEXT:    ret i64 [[SUB]]
+;
+  %sel = select i1 %c, i64 %arg, i64 64
+  %not.c = xor i1 %c, true
+  %ext = zext i1 %not.c to i64
+  %sub = sub i64 %sel, %ext
+  ret i64 %sub
+}
+
+define i64 @sub_select_not_sext(i1 %c, i64 %arg) {
+; CHECK-LABEL: define i64 @sub_select_not_sext
+; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i64 [[ARG]], i64 65
+; CHECK-NEXT:    ret i64 [[SUB]]
+;
+  %sel = select i1 %c, i64 %arg, i64 64
+  %not.c = xor i1 %c, true
+  %ext = sext i1 %not.c to i64
+  %sub = sub i64 %sel, %ext
+  ret i64 %sub
+}
+
+define i64 @mul_select_zext(i1 %c, i64 %arg) {
+; CHECK-LABEL: define i64 @mul_select_zext
+; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[MUL:%.*]] = select i1 [[C]], i64 [[ARG]], i64 0
+; CHECK-NEXT:    ret i64 [[MUL]]
+;
+  %sel = select i1 %c, i64 %arg, i64 1
+  %ext = zext i1 %c to i64
+  %mul = mul i64 %sel, %ext
+  ret i64 %mul
+}
+
+define i64 @mul_select_sext(i1 %c) {
+; CHECK-LABEL: define i64 @mul_select_sext
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[MUL:%.*]] = select i1 [[C]], i64 -64, i64 0
+; CHECK-NEXT:    ret i64 [[MUL]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %ext = sext i1 %c to i64
+  %mul = mul i64 %sel, %ext
+  ret i64 %mul
+}
+
+define i64 @select_zext_different_condition(i1 %c, i1 %d) {
+; CHECK-LABEL: define i64 @select_zext_different_condition
+; CHECK-SAME: (i1 [[C:%.*]], i1 [[D:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i64 64, i64 1
+; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[D]] to i64
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[SEL]], [[EXT]]
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %ext = zext i1 %d to i64
+  %add = add i64 %sel, %ext
+  ret i64 %add
+}
+
+define i64 @multi_use_add(i1 %c) {
+; CHECK-LABEL: define i64 @multi_use_add
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD2:%.*]] = select i1 [[C]], i64 66, i64 2
+; CHECK-NEXT:    ret i64 [[ADD2]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %ext = zext i1 %c to i64
+  %add = add i64 %sel, %ext
+  %add2 = add i64 %add, 1
+  ret i64 %add2
+}
+
+define <2 x i64> @vector_test(i1 %c) {
+; CHECK-LABEL: define <2 x i64> @vector_test
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], <2 x i64> <i64 64, i64 64>, <2 x i64> <i64 1, i64 1>
+; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[C]] to i64
+; CHECK-NEXT:    [[VEC0:%.*]] = insertelement <2 x i64> undef, i64 [[EXT]], i64 0
+; CHECK-NEXT:    [[VEC1:%.*]] = shufflevector <2 x i64> [[VEC0]], <2 x i64> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw <2 x i64> [[SEL]], [[VEC1]]
+; CHECK-NEXT:    ret <2 x i64> [[ADD]]
+;
+  %sel = select i1 %c, <2 x i64> <i64 64, i64 64>, <2 x i64> <i64 1, i64 1>
+  %ext = zext i1 %c to i64
+  %vec0 = insertelement <2 x i64> undef, i64 %ext, i32 0
+  %vec1 = insertelement <2 x i64> %vec0, i64 %ext, i32 1
+  %add = add <2 x i64> %sel, %vec1
+  ret <2 x i64> %add
+}
+
+define <2 x i64> @vector_cond_zext(<2 x i1> %c) {
+; CHECK-LABEL: define <2 x i64> @vector_cond_zext
+; CHECK-SAME: (<2 x i1> [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = select <2 x i1> [[C]], <2 x i64> <i64 65, i64 65>, <2 x i64> <i64 1, i64 1>
+; CHECK-NEXT:    ret <2 x i64> [[ADD]]
+;
+  %sel = select <2 x i1> %c, <2 x i64> <i64 64, i64 64>, <2 x i64> <i64 1, i64 1>
+  %ext = zext <2 x i1> %c to <2 x i64>
+  %add = add <2 x i64> %sel, %ext
+  ret <2 x i64> %add
+}