Index: llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1613,6 +1613,9 @@
   if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
     return Res;
 
+  if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+    return Res;
+
   return Changed ? &I : nullptr;
 }
 
@@ -1838,6 +1841,9 @@
     return Result;
   }
 
+  if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+    return Res;
+
   return nullptr;
 }
 
@@ -2461,6 +2467,9 @@
     }
   }
 
+  if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+    return Res;
+
   return TryToNarrowDeduceFlags();
 }
 
@@ -2765,5 +2774,8 @@
     }
   }
 
+  if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+    return Res;
+
   return nullptr;
 }
Index: llvm/lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -456,6 +456,12 @@
   //    -> (BinOp (logic_shift (BinOp X, Y)), Mask)
   Instruction *foldBinOpShiftWithShift(BinaryOperator &I);
 
+  /// Tries to simplify (Binop (select C, T, F), (cast C)), where cast is a
+  /// zext/sext/uitofp/sitofp of i1 C (or of !C), by folding the cast into
+  /// the select arms:
+  ///   -> (select C, (Binop T, (cast true)), (Binop F, (cast false)))
+  Instruction *foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I);
+
   /// This tries to simplify binary operations by factorizing out common terms
   /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
   Value *tryFactorizationFolds(BinaryOperator &I);
Index: llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -866,6 +866,84 @@
   return MatchBinOp(1);
 }
 
+Instruction *
+InstCombinerImpl::foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I) {
+  Instruction::BinaryOps Opc = I.getOpcode();
+
+  switch (Opc) {
+  case Instruction::Add:
+  case Instruction::FAdd:
+  case Instruction::Sub:
+  case Instruction::FSub:
+    break;
+  default:
+    return nullptr;
+  }
+
+  // Only (Binop (select C, T, F), (cast X)) is matched; a cast in operand 0
+  // is not handled.
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+  Value *A, *CondVal, *TrueVal, *FalseVal;
+
+  if (!match(LHS,
+             m_Select(m_Value(CondVal), m_Value(TrueVal), m_Value(FalseVal))))
+    return nullptr;
+
+  // Integer case: RHS is a zext/sext of an i1 A, i.e. 1 (zext) or -1 (sext)
+  // when A is true and 0 otherwise. X + 0 == X - 0 == X, so only the arm
+  // paired with A == true changes, and that arm must be a constant.
+  bool IsZExt = match(RHS, m_ZExt(m_Value(A)));
+  if ((IsZExt || match(RHS, m_SExt(m_Value(A)))) &&
+      A->getType()->getScalarSizeInBits() == 1) {
+    assert(Opc == Instruction::Add || Opc == Instruction::Sub);
+
+    // Fold the cast's "true" value (1 for zext, -1 for sext) into C.
+    auto NewFoldedConst = [&](Constant *C) {
+      if (IsZExt)
+        return Opc == Instruction::Add ? InstCombiner::AddOne(C)
+                                       : InstCombiner::SubOne(C);
+      return Opc == Instruction::Add ? InstCombiner::SubOne(C)
+                                     : InstCombiner::AddOne(C);
+    };
+
+    // (Binop (select C, T, F), (cast C)) -> (select C, T', F)
+    if (CondVal == A)
+      if (auto *TrueC = dyn_cast<Constant>(TrueVal))
+        return SelectInst::Create(CondVal, NewFoldedConst(TrueC), FalseVal);
+
+    // (Binop (select C, T, F), (cast !C)) -> (select C, T, F')
+    if (match(A, m_Not(m_Specific(CondVal))))
+      if (auto *FalseC = dyn_cast<Constant>(FalseVal))
+        return SelectInst::Create(CondVal, TrueVal, NewFoldedConst(FalseC));
+  }
+
+  // FP case: RHS is a uitofp/sitofp of an i1 A, i.e. 1.0 (uitofp) or -1.0
+  // (sitofp) when A is true and +0.0 otherwise. Adding +0.0 is NOT an
+  // identity (-0.0 + 0.0 == +0.0), so both arms must be constants and each
+  // is folded with the cast value it is paired with.
+  bool IsUIToFP = match(RHS, m_UIToFP(m_Value(A)));
+  if ((IsUIToFP || match(RHS, m_SIToFP(m_Value(A)))) &&
+      A->getType()->getScalarSizeInBits() == 1 && isa<Constant>(TrueVal) &&
+      isa<Constant>(FalseVal)) {
+    assert(Opc == Instruction::FAdd || Opc == Instruction::FSub);
+
+    Constant *TrueCast = ConstantFP::get(I.getType(), IsUIToFP ? 1.0 : -1.0);
+    Constant *FalseCast = ConstantFP::get(I.getType(), 0.0);
+    auto FoldArm = [&](Value *Arm, Constant *CastVal) {
+      return ConstantFoldBinaryInstruction(Opc, cast<Constant>(Arm), CastVal);
+    };
+
+    Constant *NewTrueC = nullptr, *NewFalseC = nullptr;
+    if (CondVal == A) {
+      NewTrueC = FoldArm(TrueVal, TrueCast);
+      NewFalseC = FoldArm(FalseVal, FalseCast);
+    } else if (match(A, m_Not(m_Specific(CondVal)))) {
+      NewTrueC = FoldArm(TrueVal, FalseCast);
+      NewFalseC = FoldArm(FalseVal, TrueCast);
+    }
+
+    // The constant folder may return null; only rewrite if both arms fold.
+    if (NewTrueC && NewFalseC)
+      return SelectInst::Create(CondVal, NewTrueC, NewFalseC);
+  }
+
+  return nullptr;
+}
+
 Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) {
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
Index: llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll
===================================================================
--- /dev/null
+++ llvm/test/Transforms/InstCombine/binop-select-cast-of-select-cond.ll
@@ -0,0 +1,210 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes=instcombine -S | FileCheck %s
+
+define i64 @add_select_zext(i1 %c) {
+; CHECK-LABEL: define i64 @add_select_zext
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[C]], i64 65, i64 1
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %ext = zext i1 %c to i64
+  %add = add i64 %sel, %ext
+  ret i64 %add
+}
+
+define i64 @add_select_sext(i1 %c) {
+; CHECK-LABEL: define i64 @add_select_sext
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[C]], i64 63, i64 1
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %ext = sext i1 %c to i64
+  %add = add i64 %sel, %ext
+  ret i64 %add
+}
+
+define i64 @add_select_not_zext(i1 %c) {
+; CHECK-LABEL: define i64 @add_select_not_zext
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[C]], i64 64, i64 2
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %not.c = xor i1 %c, true
+  %ext = zext i1 %not.c to i64
+  %add = add i64 %sel, %ext
+  ret i64 %add
+}
+
+define i64 @add_select_not_sext(i1 %c) {
+; CHECK-LABEL: define i64 @add_select_not_sext
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD:%.*]] = select i1 [[C]], i64 64, i64 0
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %not.c = xor i1 %c, true
+  %ext = sext i1 %not.c to i64
+  %add = add i64 %sel, %ext
+  ret i64 %add
+}
+
+define i64 @sub_select_sext(i1 %c, i64 %arg) {
+; CHECK-LABEL: define i64 @sub_select_sext
+; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i64 65, i64 [[ARG]]
+; CHECK-NEXT:    ret i64 [[SUB]]
+;
+  %sel = select i1 %c, i64 64, i64 %arg
+  %ext = sext i1 %c to i64
+  %sub = sub i64 %sel, %ext
+  ret i64 %sub
+}
+
+define i64 @sub_select_not_zext(i1 %c, i64 %arg) {
+; CHECK-LABEL: define i64 @sub_select_not_zext
+; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i64 [[ARG]], i64 63
+; CHECK-NEXT:    ret i64 [[SUB]]
+;
+  %sel = select i1 %c, i64 %arg, i64 64
+  %not.c = xor i1 %c, true
+  %ext = zext i1 %not.c to i64
+  %sub = sub i64 %sel, %ext
+  ret i64 %sub
+}
+
+define i64 @sub_select_not_sext(i1 %c, i64 %arg) {
+; CHECK-LABEL: define i64 @sub_select_not_sext
+; CHECK-SAME: (i1 [[C:%.*]], i64 [[ARG:%.*]]) {
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i64 [[ARG]], i64 65
+; CHECK-NEXT:    ret i64 [[SUB]]
+;
+  %sel = select i1 %c, i64 %arg, i64 64
+  %not.c = xor i1 %c, true
+  %ext = sext i1 %not.c to i64
+  %sub = sub i64 %sel, %ext
+  ret i64 %sub
+}
+
+define float @fadd_select_uitofp(i1 %c) {
+; CHECK-LABEL: define float @fadd_select_uitofp
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[FADD:%.*]] = select i1 [[C]], float 3.000000e+00, float 1.000000e+00
+; CHECK-NEXT:    ret float [[FADD]]
+;
+  %sel = select i1 %c, float 2.0, float 1.0
+  %conv = uitofp i1 %c to float
+  %fadd = fadd float %sel, %conv
+  ret float %fadd
+}
+
+; -0.0 + 0.0 is +0.0, so the arm paired with a zero cast must be folded too.
+define float @fadd_select_uitofp_negzero(i1 %c) {
+; CHECK-LABEL: define float @fadd_select_uitofp_negzero
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[FADD:%.*]] = select i1 [[C]], float 3.000000e+00, float 0.000000e+00
+; CHECK-NEXT:    ret float [[FADD]]
+;
+  %sel = select i1 %c, float 2.0, float -0.0
+  %conv = uitofp i1 %c to float
+  %fadd = fadd float %sel, %conv
+  ret float %fadd
+}
+
+define float @fadd_select_not_sitofp(i1 %c) {
+; CHECK-LABEL: define float @fadd_select_not_sitofp
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[FADD:%.*]] = select i1 [[C]], float 2.000000e+00, float 0.000000e+00
+; CHECK-NEXT:    ret float [[FADD]]
+;
+  %sel = select i1 %c, float 2.0, float 1.0
+  %not.c = xor i1 %c, true
+  %conv = sitofp i1 %not.c to float
+  %fadd = fadd float %sel, %conv
+  ret float %fadd
+}
+
+define float @fsub_select_sitofp(i1 %c) {
+; CHECK-LABEL: define float @fsub_select_sitofp
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[FSUB:%.*]] = select i1 [[C]], float 3.000000e+00, float 1.000000e+00
+; CHECK-NEXT:    ret float [[FSUB]]
+;
+  %sel = select i1 %c, float 2.0, float 1.0
+  %conv = sitofp i1 %c to float
+  %fsub = fsub float %sel, %conv
+  ret float %fsub
+}
+
+define float @fsub_select_not_uitofp(i1 %c) {
+; CHECK-LABEL: define float @fsub_select_not_uitofp
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[FSUB:%.*]] = select i1 [[C]], float 2.000000e+00, float 0.000000e+00
+; CHECK-NEXT:    ret float [[FSUB]]
+;
+  %sel = select i1 %c, float 2.0, float 1.0
+  %not.c = xor i1 %c, true
+  %conv = uitofp i1 %not.c to float
+  %fsub = fsub float %sel, %conv
+  ret float %fsub
+}
+
+define float @fsub_select_not_sitofp(i1 %c) {
+; CHECK-LABEL: define float @fsub_select_not_sitofp
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    ret float 2.000000e+00
+;
+  %sel = select i1 %c, float 2.0, float 1.0
+  %not.c = xor i1 %c, true
+  %conv = sitofp i1 %not.c to float
+  %fsub = fsub float %sel, %conv
+  ret float %fsub
+}
+
+define i64 @select_zext_different_condition(i1 %c, i1 %d) {
+; CHECK-LABEL: define i64 @select_zext_different_condition
+; CHECK-SAME: (i1 [[C:%.*]], i1 [[D:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], i64 64, i64 1
+; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[D]] to i64
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw i64 [[SEL]], [[EXT]]
+; CHECK-NEXT:    ret i64 [[ADD]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %ext = zext i1 %d to i64
+  %add = add i64 %sel, %ext
+  ret i64 %add
+}
+
+define i64 @multi_use_add(i1 %c) {
+; CHECK-LABEL: define i64 @multi_use_add
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[ADD2:%.*]] = select i1 [[C]], i64 66, i64 2
+; CHECK-NEXT:    ret i64 [[ADD2]]
+;
+  %sel = select i1 %c, i64 64, i64 1
+  %ext = zext i1 %c to i64
+  %add = add i64 %sel, %ext
+  %add2 = add i64 %add, 1
+  ret i64 %add2
+}
+
+define <2 x i64> @vector_test(i1 %c) {
+; CHECK-LABEL: define <2 x i64> @vector_test
+; CHECK-SAME: (i1 [[C:%.*]]) {
+; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[C]], <2 x i64> <i64 64, i64 64>, <2 x i64> <i64 1, i64 1>
+; CHECK-NEXT:    [[EXT:%.*]] = zext i1 [[C]] to i64
+; CHECK-NEXT:    [[VEC0:%.*]] = insertelement <2 x i64> undef, i64 [[EXT]], i64 0
+; CHECK-NEXT:    [[VEC1:%.*]] = shufflevector <2 x i64> [[VEC0]], <2 x i64> poison, <2 x i32> zeroinitializer
+; CHECK-NEXT:    [[ADD:%.*]] = add nuw nsw <2 x i64> [[SEL]], [[VEC1]]
+; CHECK-NEXT:    ret <2 x i64> [[ADD]]
+;
+  %sel = select i1 %c, <2 x i64> <i64 64, i64 64>, <2 x i64> <i64 1, i64 1>
+  %ext = zext i1 %c to i64
+  %vec0 = insertelement <2 x i64> undef, i64 %ext, i32 0
+  %vec1 = insertelement <2 x i64> %vec0, i64 %ext, i32 1
+  %add = add <2 x i64> %sel, %vec1
+  ret <2 x i64> %add
+}