diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1732,6 +1732,48 @@
   return nullptr;
 }
 
+/// Fold a select of an unsigned add/sub-with-overflow result into the
+/// matching saturating intrinsic:
+///   X + Y overflows ? -1 : X + Y -> uadd_sat X, Y
+///   X - Y overflows ?  0 : X - Y -> usub_sat X, Y
+/// Returns a newly created call (not inserted into any block), or nullptr if
+/// the select does not match either pattern. Builder is currently unused; the
+/// signature mirrors foldAddSubSelect.
+static Instruction *
+foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) {
+  Value *CondVal = SI.getCondition();
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+
+  // The condition must be the overflow bit (field 1) and the false arm the
+  // arithmetic result (field 0) of the same with.overflow call.
+  WithOverflowInst *II;
+  if (!match(CondVal, m_ExtractValue<1>(m_WithOverflowInst(II))) ||
+      !match(FalseVal, m_ExtractValue<0>(m_Specific(II))))
+    return nullptr;
+
+  Value *X = II->getLHS();
+  Value *Y = II->getRHS();
+
+  Intrinsic::ID NewIntrinsicID;
+  if (II->getIntrinsicID() == Intrinsic::uadd_with_overflow &&
+      match(TrueVal, m_AllOnes()))
+    // X + Y overflows ? -1 : X + Y -> uadd_sat X, Y
+    NewIntrinsicID = Intrinsic::uadd_sat;
+  else if (II->getIntrinsicID() == Intrinsic::usub_with_overflow &&
+           match(TrueVal, m_Zero()))
+    // X - Y overflows ? 0 : X - Y -> usub_sat X, Y
+    NewIntrinsicID = Intrinsic::usub_sat;
+  else
+    return nullptr;
+
+  // The false arm is the with.overflow arithmetic result, so the select's
+  // type is also the operand type the saturating intrinsic is declared on.
+  Function *F =
+      Intrinsic::getDeclaration(SI.getModule(), NewIntrinsicID, SI.getType());
+  return CallInst::Create(F, {X, Y});
+}
+
 Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) {
   Constant *C;
   if (!match(Sel.getTrueValue(), m_Constant(C)) &&
@@ -2398,6 +2440,8 @@
 
   if (Instruction *Add = foldAddSubSelect(SI, Builder))
     return Add;
+  if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder))
+    return Add;
 
   // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z))
   auto *TI = dyn_cast<Instruction>(TrueVal);
diff --git a/llvm/test/Transforms/InstCombine/overflow_to_sat.ll b/llvm/test/Transforms/InstCombine/overflow_to_sat.ll
--- a/llvm/test/Transforms/InstCombine/overflow_to_sat.ll
+++ b/llvm/test/Transforms/InstCombine/overflow_to_sat.ll
@@ -3,10 +3,7 @@
 
 define i32 @uadd(i32 %x, i32 %y) {
 ; CHECK-LABEL: @uadd(
-; CHECK-NEXT:    [[AO:%.*]] = tail call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
-; CHECK-NEXT:    [[O:%.*]] = extractvalue { i32, i1 } [[AO]], 1
-; CHECK-NEXT:    [[A:%.*]] = extractvalue { i32, i1 } [[AO]], 0
-; CHECK-NEXT:    [[S:%.*]] = select i1 [[O]], i32 -1, i32 [[A]]
+; CHECK-NEXT:    [[S:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT:    ret i32 [[S]]
 ;
   %ao = tail call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %x, i32 %y)
@@ -18,10 +15,7 @@
 
 define i32 @usub(i32 %x, i32 %y) {
 ; CHECK-LABEL: @usub(
-; CHECK-NEXT:    [[AO:%.*]] = tail call { i32, i1 } @llvm.usub.with.overflow.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
-; CHECK-NEXT:    [[O:%.*]] = extractvalue { i32, i1 } [[AO]], 1
-; CHECK-NEXT:    [[A:%.*]] = extractvalue { i32, i1 } [[AO]], 0
-; CHECK-NEXT:    [[S:%.*]] = select i1 [[O]], i32 0, i32 [[A]]
+; CHECK-NEXT:    [[S:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT:    ret i32 [[S]]
 ;
   %ao = tail call { i32, i1 } @llvm.usub.with.overflow.i32(i32 %x, i32 %y)