Index: llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2434,6 +2434,83 @@
   return nullptr;
 }
 
+bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
+                                           Value *&RHS, ConstantInt *&Less,
+                                           ConstantInt *&Equal,
+                                           ConstantInt *&Greater) {
+  // TODO: Generalize this to work with other comparison idioms or ensure
+  // they get canonicalized into this form.
+
+  // select i1 (a == b), i32 Equal, i32 (select i1 (a s< b), i32 Less, i32
+  // Greater), where Equal, Less and Greater are placeholders for any three
+  // constants.
+  ICmpInst::Predicate PredA, PredB;
+  if (match(SI->getTrueValue(), m_ConstantInt(Equal)) &&
+      match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) &&
+      PredA == ICmpInst::ICMP_EQ &&
+      match(SI->getFalseValue(),
+            m_Select(m_ICmp(PredB, m_Specific(LHS), m_Specific(RHS)),
+                     m_ConstantInt(Less), m_ConstantInt(Greater))) &&
+      PredB == ICmpInst::ICMP_SLT) {
+    return true;
+  }
+  return false;
+}
+
+/// Fold icmp Pred X, C where X is a three-way compare select chain over two
+/// values (see matchThreeWayIntCompare) into comparisons of those values.
+/// Returns nullptr unless the idiom matches and Cmp has exactly one use.
+Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp,
+                                                  Instruction *Select,
+                                                  ConstantInt *C) {
+
+  assert(C && "Cmp RHS should be a constant int!");
+  // If we're testing a constant value against the result of a three way
+  // comparison, the result can be expressed directly in terms of the
+  // original values being compared.  Note: We could possibly be more
+  // aggressive here and remove the hasOneUse test. The original select is
+  // really likely to simplify or sink when we remove a test of the result.
+  Value *OrigLHS, *OrigRHS;
+  ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan;
+  if (Cmp.hasOneUse() &&
+      matchThreeWayIntCompare(cast<SelectInst>(Select), OrigLHS, OrigRHS,
+                              C1LessThan, C2Equal, C3GreaterThan)) {
+    assert(C1LessThan && C2Equal && C3GreaterThan);
+
+    bool TrueWhenLessThan =
+        ConstantExpr::getCompare(Cmp.getPredicate(), C1LessThan, C)
+            ->isAllOnesValue();
+    bool TrueWhenEqual =
+        ConstantExpr::getCompare(Cmp.getPredicate(), C2Equal, C)
+            ->isAllOnesValue();
+    bool TrueWhenGreaterThan =
+        ConstantExpr::getCompare(Cmp.getPredicate(), C3GreaterThan, C)
+            ->isAllOnesValue();
+
+    // This generates the new instruction that will replace the original Cmp
+    // Instruction. Instead of enumerating the various combinations when
+    // TrueWhenLessThan, TrueWhenEqual and TrueWhenGreaterThan are true versus
+    // false, we rely on chaining of ORs and future passes of InstCombine to
+    // simplify the OR further (e.g. a s< b || a == b becomes a s<= b).
+
+    // When none of the three constants satisfy the predicate for the RHS (C),
+    // the entire original Cmp can be simplified to a false.
+    Value *Cond = Builder->getFalse();
+    if (TrueWhenLessThan)
+      Cond = Builder->CreateOr(
+          Cond, Builder->CreateICmp(ICmpInst::ICMP_SLT, OrigLHS, OrigRHS));
+    if (TrueWhenEqual)
+      Cond = Builder->CreateOr(
+          Cond, Builder->CreateICmp(ICmpInst::ICMP_EQ, OrigLHS, OrigRHS));
+    if (TrueWhenGreaterThan)
+      Cond = Builder->CreateOr(
+          Cond, Builder->CreateICmp(ICmpInst::ICMP_SGT, OrigLHS, OrigRHS));
+
+    return replaceInstUsesWith(Cmp, Cond);
+  }
+  return nullptr;
+}
+
 /// Try to fold integer comparisons with a constant operand: icmp Pred X, C
 /// where X is some kind of instruction.
 Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) {
@@ -2493,11 +2570,28 @@
       return I;
   }
 
+  // Match against CmpInst LHS being instructions other than binary operators.
   Instruction *LHSI;
-  if (match(Cmp.getOperand(0), m_Instruction(LHSI)) &&
-      LHSI->getOpcode() == Instruction::Trunc)
-    if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C))
-      return I;
+  if (match(Cmp.getOperand(0), m_Instruction(LHSI))) {
+    switch (LHSI->getOpcode()) {
+    case Instruction::Select:
+      {
+        // For now, we only support constant integers while folding the
+        // ICMP(SELECT) pattern. We can extend this to support vector of integers
+        // similar to the cases handled by binary ops above.
+        if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1)))
+          if (Instruction *I = foldICmpSelectConstant(Cmp, LHSI, ConstRHS))
+            return I;
+        break;
+      }
+    case Instruction::Trunc:
+      if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C))
+        return I;
+      break;
+    default:
+      break;
+    }
+  }
 
   if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, C))
     return I;
Index: llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
===================================================================
--- llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
+++ llvm/trunk/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -603,6 +603,15 @@
                           Instruction::BinaryOps, Value *, Value *, Value *,
                           Value *);
 
+  /// Match a select chain which produces one of three values based on whether
+  /// the LHS is less than, equal to, or greater than RHS respectively.
+  /// Return true if we matched a three way compare idiom. The LHS, RHS, Less,
+  /// Equal and Greater values are saved in the matching process and returned to
+  /// the caller.
+  bool matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, Value *&RHS,
+                               ConstantInt *&Less, ConstantInt *&Equal,
+                               ConstantInt *&Greater);
+
   /// \brief Attempts to replace V with a simpler value based on the demanded
   /// bits.
   Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBits &Known,
@@ -680,6 +689,8 @@
   Instruction *foldICmpBinOp(ICmpInst &Cmp);
   Instruction *foldICmpEquality(ICmpInst &Cmp);
 
+  Instruction *foldICmpSelectConstant(ICmpInst &Cmp, Instruction *Select,
+                                      ConstantInt *C);
   Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc,
                                      const APInt *C);
   Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And,
Index: llvm/trunk/test/Transforms/InstCombine/compare-3way.ll
===================================================================
--- llvm/trunk/test/Transforms/InstCombine/compare-3way.ll
+++ llvm/trunk/test/Transforms/InstCombine/compare-3way.ll
@@ -0,0 +1,395 @@
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+declare void @use(i32)
+
+; These 18 tests compare the result of a typical three-way compare
+; (-1, 0, 1) against each of its three values ("low", "mid", "high")
+; using each of the predicates sgt, slt, sge, sle, ne and eq.
+
+define void @test_low_sgt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_sgt
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sgt i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_slt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_slt
+; CHECK: br i1 false, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp slt i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_sge(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_sge
+; CHECK: br i1 true, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sge i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_sle(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_sle
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sle i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_ne(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_ne
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp ne i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_eq(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_eq
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp eq i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_sgt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_sgt
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sgt i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_slt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_slt
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp slt i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_sge(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_sge
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sge i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_sle(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_sle
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sle i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_ne(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_ne
+; CHECK: [[TMP1:%.*]] = icmp eq i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp ne i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_eq(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_eq
+; CHECK: icmp eq i64 %a, %b
+; CHECK: br i1 %eq, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp eq i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_sgt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_sgt
+; CHECK: br i1 false, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sgt i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_slt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_slt
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp slt i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_sge(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_sge
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sge i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_sle(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_sle
+; CHECK: br i1 true, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sle i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_ne(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_ne
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp ne i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_eq(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_eq
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp eq i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+; These five use the constants -3, -2 and -1 to make sure the fold does not
+; hard-code the values (-1, 0, 1) produced by the three-way compare.
+
+define void @non_standard_low(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_low
+; CHECK: [[TMP1:%.*]] = icmp slt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, -3
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @non_standard_mid(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_mid
+; CHECK: icmp eq i64 %a, %b
+; CHECK: br i1 %eq, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, -2
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @non_standard_high(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_high
+; CHECK: [[TMP1:%.*]] = icmp sgt i64 %a, %b
+; CHECK: br i1 [[TMP1]], label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @non_standard_bound1(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_bound1
+; CHECK: br i1 false, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, -20
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @non_standard_bound2(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_bound2
+; CHECK: br i1 false, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}