Index: llvm/include/llvm/IR/PatternMatch.h
===================================================================
--- llvm/include/llvm/IR/PatternMatch.h
+++ llvm/include/llvm/IR/PatternMatch.h
@@ -35,6 +35,7 @@
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/Value.h"
@@ -558,6 +559,10 @@
 inline bind_ty<Instruction> m_Instruction(Instruction *&I) { return I; }
 /// Match a binary operator, capturing it if we match.
 inline bind_ty<BinaryOperator> m_BinOp(BinaryOperator *&I) { return I; }
+/// Match a binary operator intrinsic, capturing it if we match.
+inline bind_ty<BinaryOpIntrinsic> m_BinOpIntrinsic(BinaryOpIntrinsic *&I) {
+  return I;
+}
 
 /// Match a ConstantInt, capturing the value if we match.
 inline bind_ty<ConstantInt> m_ConstantInt(ConstantInt *&CI) { return CI; }
@@ -1937,6 +1942,30 @@
   return Signum_match<Val_t>(V);
 }
 
+/// Matcher for a single-index extractvalue instruction whose index is Ind and
+/// whose aggregate operand matches the sub-pattern Val.
+template <int Ind, typename Opnd_t> struct ExtractValue_match {
+  Opnd_t Val;
+  ExtractValue_match(const Opnd_t &V) : Val(V) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    // getIndices() holds unsigned values, so compare against an explicitly
+    // converted Ind rather than mixing signed and unsigned operands.
+    if (auto *I = dyn_cast<ExtractValueInst>(V))
+      return I->getNumIndices() == 1 &&
+             I->getIndices()[0] == static_cast<unsigned>(Ind) &&
+             Val.match(I->getAggregateOperand());
+    return false;
+  }
+};
+
+/// Match a single index ExtractValue instruction.
+/// For example m_ExtractValue<1>(...)
+template <int Ind, typename Val_t>
+inline ExtractValue_match<Ind, Val_t> m_ExtractValue(const Val_t &V) {
+  return ExtractValue_match<Ind, Val_t>(V);
+}
+
 } // end namespace PatternMatch
 } // end namespace llvm
 
Index: llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1732,6 +1732,42 @@
   return nullptr;
 }
 
+/// Turn X + Y overflows ? -1 : X + Y -> uadd_sat X, Y
+/// And X - Y overflows ? 0 : X - Y -> usub_sat X, Y
+///
+/// The select yields the saturation limit exactly when the with.overflow call
+/// reports overflow and the wrapped result otherwise, which matches the
+/// LangRef semantics of llvm.uadd.sat / llvm.usub.sat.
+static Instruction *
+foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) {
+  Value *CondVal = SI.getCondition();
+  Value *TrueVal = SI.getTrueValue();
+  Value *FalseVal = SI.getFalseValue();
+
+  // The condition must be the overflow bit (index 1) and the false arm the
+  // result (index 0) of the same binary-op intrinsic call.
+  BinaryOpIntrinsic *II;
+  if (!match(CondVal, m_ExtractValue<1>(m_BinOpIntrinsic(II))) ||
+      !match(FalseVal, m_ExtractValue<0>(m_Specific(II))))
+    return nullptr;
+
+  Intrinsic::ID NewIntrinsicID;
+  if (II->getIntrinsicID() == Intrinsic::uadd_with_overflow &&
+      match(TrueVal, m_AllOnes()))
+    // X + Y overflows ? -1 : X + Y -> uadd_sat X, Y
+    NewIntrinsicID = Intrinsic::uadd_sat;
+  else if (II->getIntrinsicID() == Intrinsic::usub_with_overflow &&
+           match(TrueVal, m_Zero()))
+    // X - Y overflows ? 0 : X - Y -> usub_sat X, Y
+    NewIntrinsicID = Intrinsic::usub_sat;
+  else
+    return nullptr;
+
+  Function *F =
+      Intrinsic::getDeclaration(SI.getModule(), NewIntrinsicID, SI.getType());
+  return CallInst::Create(F, {II->getLHS(), II->getRHS()});
+}
+
 Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) {
   Constant *C;
   if (!match(Sel.getTrueValue(), m_Constant(C)) &&
@@ -2398,6 +2434,8 @@
 
   if (Instruction *Add = foldAddSubSelect(SI, Builder))
     return Add;
+  if (Instruction *Sat = foldOverflowingAddSubSelect(SI, Builder))
+    return Sat;
 
   // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z))
   auto *TI = dyn_cast<Instruction>(TrueVal);
Index: llvm/test/Transforms/InstCombine/overflow_to_sat.ll
===================================================================
--- llvm/test/Transforms/InstCombine/overflow_to_sat.ll
+++ llvm/test/Transforms/InstCombine/overflow_to_sat.ll
@@ -3,10 +3,7 @@
 
 define i32 @uadd(i32 %x, i32 %y) {
 ; CHECK-LABEL: @uadd(
-; CHECK-NEXT:    [[AO:%.*]] = tail call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
-; CHECK-NEXT:    [[O:%.*]] = extractvalue { i32, i1 } [[AO]], 1
-; CHECK-NEXT:    [[A:%.*]] = extractvalue { i32, i1 } [[AO]], 0
-; CHECK-NEXT:    [[S:%.*]] = select i1 [[O]], i32 -1, i32 [[A]]
+; CHECK-NEXT:    [[S:%.*]] = call i32 @llvm.uadd.sat.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT:    ret i32 [[S]]
 ;
   %ao = tail call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %x, i32 %y)
@@ -18,10 +15,7 @@
 
 define i32 @usub(i32 %x, i32 %y) {
 ; CHECK-LABEL: @usub(
-; CHECK-NEXT:    [[AO:%.*]] = tail call { i32, i1 } @llvm.usub.with.overflow.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
-; CHECK-NEXT:    [[O:%.*]] = extractvalue { i32, i1 } [[AO]], 1
-; CHECK-NEXT:    [[A:%.*]] = extractvalue { i32, i1 } [[AO]], 0
-; CHECK-NEXT:    [[S:%.*]] = select i1 [[O]], i32 0, i32 [[A]]
+; CHECK-NEXT:    [[S:%.*]] = call i32 @llvm.usub.sat.i32(i32 [[X:%.*]], i32 [[Y:%.*]])
 ; CHECK-NEXT:    ret i32 [[S]]
 ;
   %ao = tail call { i32, i1 } @llvm.usub.with.overflow.i32(i32 %x, i32 %y)