Index: lib/Transforms/InstCombine/InstCombineCompares.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -2169,6 +2169,86 @@
                             ConstantExpr::getNeg(LHSC));
     }
     break;
+  case Instruction::Select: {
+    // Match a select chain which produces one of three constant values
+    // depending on whether LHS is signed-less-than, equal to, or
+    // signed-greater-than RHS:
+    //
+    //   select i1 (a == b), iN Equal,
+    //             (select i1 (a s< b), iN Less, iN Greater)
+    //
+    // On success, binds the two compared values and the three constants.
+    // TODO: Generalize this to work with other comparison idioms or ensure
+    // they get canonicalized into this form.
+    auto matchThreeWaySignedCompare = [](SelectInst *SI, Value *&CmpLHS,
+                                         Value *&CmpRHS, ConstantInt *&Less,
+                                         ConstantInt *&Equal,
+                                         ConstantInt *&Greater) {
+      ICmpInst::Predicate PredA, PredB;
+      if (!match(SI->getTrueValue(), m_ConstantInt(Equal)) ||
+          !match(SI->getCondition(),
+                 m_ICmp(PredA, m_Value(CmpLHS), m_Value(CmpRHS))) ||
+          !match(SI->getFalseValue(),
+                 m_Select(m_ICmp(PredB, m_Specific(CmpLHS),
+                                 m_Specific(CmpRHS)),
+                          m_ConstantInt(Less), m_ConstantInt(Greater))))
+        return false;
+      return PredA == ICmpInst::ICMP_EQ && PredB == ICmpInst::ICMP_SLT;
+    };
+
+    // If we're testing a constant value against the result of a three way
+    // comparison, the result can be expressed directly in terms of the
+    // original values being compared. Note: We could possibly be more
+    // aggressive here and remove the hasOneUse test. The original select is
+    // really likely to simplify or sink when we remove a test of the result.
+    Value *OrigLHS, *OrigRHS;
+    ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan;
+    if (ICI.hasOneUse() &&
+        matchThreeWaySignedCompare(cast<SelectInst>(LHSI), OrigLHS, OrigRHS,
+                                   C1LessThan, C2Equal, C3GreaterThan)) {
+      assert(C1LessThan && C2Equal && C3GreaterThan && RHS);
+
+      // Constant-fold the original predicate against each possible result.
+      auto PredHolds = [&](ConstantInt *Result) {
+        return ConstantExpr::getCompare(ICI.getPredicate(), Result, RHS)
+            ->isAllOnesValue();
+      };
+      bool TrueWhenLessThan = PredHolds(C1LessThan);
+      bool TrueWhenEqual = PredHolds(C2Equal);
+      bool TrueWhenGreaterThan = PredHolds(C3GreaterThan);
+
+      // Encode the outcomes for which the original compare is true as a
+      // 3-bit mask (less, equal, greater), then replace it with the single
+      // compare of the original operands that is true for exactly those
+      // outcomes.
+      unsigned Outcomes = (TrueWhenLessThan ? 4u : 0u) |
+                          (TrueWhenEqual ? 2u : 0u) |
+                          (TrueWhenGreaterThan ? 1u : 0u);
+      switch (Outcomes) {
+      case 0: // Never true.
+        return replaceInstUsesWith(ICI, Builder->getFalse());
+      case 1: // a s> b
+        return new ICmpInst(ICmpInst::ICMP_SGT, OrigLHS, OrigRHS);
+      case 2: // a == b
+        return new ICmpInst(ICmpInst::ICMP_EQ, OrigLHS, OrigRHS);
+      case 3: // a s>= b
+        return new ICmpInst(ICmpInst::ICMP_SGE, OrigLHS, OrigRHS);
+      case 4: // a s< b
+        return new ICmpInst(ICmpInst::ICMP_SLT, OrigLHS, OrigRHS);
+      case 5: // a != b
+        return new ICmpInst(ICmpInst::ICMP_NE, OrigLHS, OrigRHS);
+      case 6: // a s<= b
+        return new ICmpInst(ICmpInst::ICMP_SLE, OrigLHS, OrigRHS);
+      case 7: // Always true.
+        return replaceInstUsesWith(ICI, Builder->getTrue());
+      }
+      llvm_unreachable("missed a combination");
+    }
+
+    // TODO: This could obviously be extended to handle three way unsigned
+    // compare idioms as well. Also, fcmp idioms?
+    break;
+  }
   }
 
   // Simplify icmp_eq and icmp_ne instructions with integer constant RHS.
Index: test/Transforms/InstCombine/compare-3way.ll
===================================================================
--- test/Transforms/InstCombine/compare-3way.ll
+++ test/Transforms/InstCombine/compare-3way.ll
@@ -0,0 +1,395 @@
+; RUN: opt -S -instcombine < %s | FileCheck %s
+
+declare void @use(i32)
+
+; These 18 tests exercise every signed comparison predicate against
+; each of the three values produced by a typical three-way compare
+; idiom (-1 for less, 0 for equal, 1 for greater).
+
+define void @test_low_sgt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_sgt
+; CHECK: %cmp = icmp slt i64 %a, %b
+; CHECK: br i1 %cmp, label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sgt i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_slt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_slt
+; CHECK: br i1 false, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp slt i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_sge(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_sge
+; CHECK: br i1 true, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sge i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_sle(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_sle
+; CHECK: %cmp = icmp slt i64 %a, %b
+; CHECK: br i1 %cmp, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sle i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_ne(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_ne
+; CHECK: %cmp = icmp slt i64 %a, %b
+; CHECK: br i1 %cmp, label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp ne i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_low_eq(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_low_eq
+; CHECK: %cmp = icmp slt i64 %a, %b
+; CHECK: br i1 %cmp, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp eq i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_sgt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_sgt
+; CHECK: %cmp = icmp sgt i64 %a, %b
+; CHECK: br i1 %cmp, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sgt i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_slt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_slt
+; CHECK: %cmp = icmp slt i64 %a, %b
+; CHECK: br i1 %cmp, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp slt i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_sge(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_sge
+; CHECK: %cmp = icmp slt i64 %a, %b
+; CHECK: br i1 %cmp, label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sge i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_sle(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_sle
+; CHECK: %cmp = icmp sgt i64 %a, %b
+; CHECK: br i1 %cmp, label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sle i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_ne(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_ne
+; CHECK: %cmp = icmp eq i64 %a, %b
+; CHECK: br i1 %cmp, label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp ne i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_mid_eq(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_mid_eq
+; CHECK: icmp eq i64 %a, %b
+; CHECK: br i1 %eq, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp eq i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_sgt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_sgt
+; CHECK: br i1 false, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sgt i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_slt(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_slt
+; CHECK: %cmp = icmp sgt i64 %a, %b
+; CHECK: br i1 %cmp, label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp slt i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_sge(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_sge
+; CHECK: %cmp = icmp sgt i64 %a, %b
+; CHECK: br i1 %cmp, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sge i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_sle(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_sle
+; CHECK: br i1 true, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp sle i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_ne(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_ne
+; CHECK: %cmp = icmp sgt i64 %a, %b
+; CHECK: br i1 %cmp, label %normal, label %unreached
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp ne i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @test_high_eq(i64 %a, i64 %b) {
+; CHECK-LABEL: @test_high_eq
+; CHECK: %cmp = icmp sgt i64 %a, %b
+; CHECK: br i1 %cmp, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -1, i32 1
+  %result = select i1 %eq, i32 0, i32 %.
+  %cmp = icmp eq i32 %result, 1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+; These five use non-standard result values (-3, -2, -1) to make sure
+; none of the produced values is accidentally hard coded in the fold.
+
+define void @non_standard_low(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_low
+; CHECK: %cmp = icmp slt i64 %a, %b
+; CHECK: br i1 %cmp, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, -3
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @non_standard_mid(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_mid
+; CHECK: icmp eq i64 %a, %b
+; CHECK: br i1 %eq, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, -2
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @non_standard_high(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_high
+; CHECK: %cmp = icmp sgt i64 %a, %b
+; CHECK: br i1 %cmp, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, -1
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @non_standard_bound1(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_bound1
+; CHECK: br i1 false, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, -20
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}
+
+define void @non_standard_bound2(i64 %a, i64 %b) {
+; CHECK-LABEL: @non_standard_bound2
+; CHECK: br i1 false, label %unreached, label %normal
+  %eq = icmp eq i64 %a, %b
+  %slt = icmp slt i64 %a, %b
+  %. = select i1 %slt, i32 -3, i32 -1
+  %result = select i1 %eq, i32 -2, i32 %.
+  %cmp = icmp eq i32 %result, 0
+  br i1 %cmp, label %unreached, label %normal
+normal:
+  ret void
+unreached:
+  call void @use(i32 %result)
+  ret void
+}