diff --git a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h --- a/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h +++ b/llvm/include/llvm/Transforms/InstCombine/InstCombiner.h @@ -213,6 +213,17 @@ Pred, Constant *C); + /// Return true if folding a 'not' into SI's condition should be avoided. + /// a ? b : false and a ? true : b are the canonical form of logical and/or, + /// including !a ? b : false and !a ? true : b. Absorbing the not into the + /// select by swapping arms would break their recognition in other analyses. + static bool shouldAvoidAbsorbingNotIntoSelect(const SelectInst &SI) { + return match(&SI, PatternMatch::m_LogicalAnd(PatternMatch::m_Value(), + PatternMatch::m_Value())) || + match(&SI, PatternMatch::m_LogicalOr(PatternMatch::m_Value(), + PatternMatch::m_Value())); + } + /// Return true if the specified value is free to invert (apply ~ to). /// This happens in cases where the ~ can be eliminated. If WillInvertAllUses /// is true, work under the assumption that the caller intends to remove all @@ -267,6 +278,8 @@ case Instruction::Select: if (U.getOperandNo() != 0) // Only if the value is used as select cond. return false; + if (shouldAvoidAbsorbingNotIntoSelect(*cast<SelectInst>(I))) + return false; break; case Instruction::Br: assert(U.getOperandNo() == 0 && "Must be branching on that value."); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -47,6 +47,11 @@ #define DEBUG_TYPE "instcombine" +/// FIXME: Enabled by default until select-form logical and/or is handled well.
+static cl::opt<bool> EnableUnsafeSelectTransform( + "instcombine-unsafe-select-transform", cl::init(true), + cl::desc("Enable poison-unsafe select to and/or transform")); + static Value *createMinMax(InstCombiner::BuilderTy &Builder, SelectPatternFlavor SPF, Value *A, Value *B) { CmpInst::Predicate Pred = getMinMaxPred(SPF); @@ -2567,38 +2572,43 @@ if (SelType->isIntOrIntVectorTy(1) && TrueVal->getType() == CondVal->getType()) { - if (match(TrueVal, m_One())) { + if (EnableUnsafeSelectTransform && match(TrueVal, m_One())) { // Change: A = select B, true, C --> A = or B, C return BinaryOperator::CreateOr(CondVal, FalseVal); } - if (match(TrueVal, m_Zero())) { - // Change: A = select B, false, C --> A = and !B, C - Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); - return BinaryOperator::CreateAnd(NotCond, FalseVal); - } - if (match(FalseVal, m_Zero())) { + if (EnableUnsafeSelectTransform && match(FalseVal, m_Zero())) { // Change: A = select B, C, false --> A = and B, C return BinaryOperator::CreateAnd(CondVal, TrueVal); } + + // select a, false, b -> select !a, b, false + if (match(TrueVal, m_Zero())) { + Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); + return SelectInst::Create(NotCond, FalseVal, + ConstantInt::getFalse(SelType)); + } + // select a, b, true -> select !a, true, b if (match(FalseVal, m_One())) { - // Change: A = select B, C, true --> A = or !B, C Value *NotCond = Builder.CreateNot(CondVal, "not." 
+ CondVal->getName()); - return BinaryOperator::CreateOr(NotCond, TrueVal); + return SelectInst::Create(NotCond, ConstantInt::getTrue(SelType), + TrueVal); } - // select a, a, b -> a | b - // select a, b, a -> a & b + // select a, a, b -> select a, true, b if (CondVal == TrueVal) - return BinaryOperator::CreateOr(CondVal, FalseVal); + return replaceOperand(SI, 1, ConstantInt::getTrue(SelType)); + // select a, b, a -> select a, b, false if (CondVal == FalseVal) - return BinaryOperator::CreateAnd(CondVal, TrueVal); + return replaceOperand(SI, 2, ConstantInt::getFalse(SelType)); - // select a, ~a, b -> (~a) & b - // select a, b, ~a -> (~a) | b + // select a, !a, b -> select !a, b, false if (match(TrueVal, m_Not(m_Specific(CondVal)))) - return BinaryOperator::CreateAnd(TrueVal, FalseVal); + return SelectInst::Create(TrueVal, FalseVal, + ConstantInt::getFalse(SelType)); + // select a, b, !a -> select !a, true, b if (match(FalseVal, m_Not(m_Specific(CondVal)))) - return BinaryOperator::CreateOr(TrueVal, FalseVal); + return SelectInst::Create(FalseVal, ConstantInt::getTrue(SelType), + TrueVal); } // Selecting between two integer or vector splat integer constants? @@ -2942,7 +2952,8 @@ } Value *NotCond; - if (match(CondVal, m_Not(m_Value(NotCond)))) { + if (match(CondVal, m_Not(m_Value(NotCond))) && + !InstCombiner::shouldAvoidAbsorbingNotIntoSelect(SI)) { replaceOperand(SI, 0, NotCond); SI.swapValues(); SI.swapProfMetadata(); diff --git a/llvm/test/Transforms/InstCombine/select-and-or.ll b/llvm/test/Transforms/InstCombine/select-and-or.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/InstCombine/select-and-or.ll @@ -0,0 +1,87 @@ +; NOTE: Assertions have been autogenerated by utils/update_test_checks.py +; RUN: opt -S -instcombine -instcombine-unsafe-select-transform=0 < %s | FileCheck %s + +; Should not be converted to "and", which has different poison semantics.
+define i1 @logical_and(i1 %a, i1 %b) { +; CHECK-LABEL: @logical_and( +; CHECK-NEXT: [[RES:%.*]] = select i1 [[A:%.*]], i1 [[B:%.*]], i1 false +; CHECK-NEXT: ret i1 [[RES]] +; + %res = select i1 %a, i1 %b, i1 false + ret i1 %res +} + +; Should not be converted to "or", which has different poison semantics. +define i1 @logical_or(i1 %a, i1 %b) { +; CHECK-LABEL: @logical_or( +; CHECK-NEXT: [[RES:%.*]] = select i1 [[A:%.*]], i1 true, i1 [[B:%.*]] +; CHECK-NEXT: ret i1 [[RES]] +; + %res = select i1 %a, i1 true, i1 %b + ret i1 %res +} +; Canonicalize to logical and form, even if that requires adding a "not". +define i1 @logical_and_not(i1 %a, i1 %b) { +; CHECK-LABEL: @logical_and_not( +; CHECK-NEXT: [[NOT_A:%.*]] = xor i1 [[A:%.*]], true +; CHECK-NEXT: [[RES:%.*]] = select i1 [[NOT_A]], i1 [[B:%.*]], i1 false +; CHECK-NEXT: ret i1 [[RES]] +; + %res = select i1 %a, i1 false, i1 %b + ret i1 %res +} + +; Canonicalize to logical or form, even if that requires adding a "not". +define i1 @logical_or_not(i1 %a, i1 %b) { +; CHECK-LABEL: @logical_or_not( +; CHECK-NEXT: [[NOT_A:%.*]] = xor i1 [[A:%.*]], true +; CHECK-NEXT: [[RES:%.*]] = select i1 [[NOT_A]], i1 true, i1 [[B:%.*]] +; CHECK-NEXT: ret i1 [[RES]] +; + %res = select i1 %a, i1 %b, i1 true + ret i1 %res +} + +; Variants where the condition or !condition stands in for true or false in +; one select arm; that arm should be canonicalized to the constant. The +; *_not_cond_reuse tests fold to the opposite op of their plain counterparts.
+ +define i1 @logical_and_cond_reuse(i1 %a, i1 %b) { +; CHECK-LABEL: @logical_and_cond_reuse( +; CHECK-NEXT: [[RES:%.*]] = select i1 [[A:%.*]], i1 [[B:%.*]], i1 false +; CHECK-NEXT: ret i1 [[RES]] +; + %res = select i1 %a, i1 %b, i1 %a + ret i1 %res +} + +define i1 @logical_or_cond_reuse(i1 %a, i1 %b) { +; CHECK-LABEL: @logical_or_cond_reuse( +; CHECK-NEXT: [[RES:%.*]] = select i1 [[A:%.*]], i1 true, i1 [[B:%.*]] +; CHECK-NEXT: ret i1 [[RES]] +; + %res = select i1 %a, i1 %a, i1 %b + ret i1 %res +} + +define i1 @logical_and_not_cond_reuse(i1 %a, i1 %b) { +; CHECK-LABEL: @logical_and_not_cond_reuse( +; CHECK-NEXT: [[A_NOT:%.*]] = xor i1 [[A:%.*]], true +; CHECK-NEXT: [[RES:%.*]] = select i1 [[A_NOT]], i1 true, i1 [[B:%.*]] +; CHECK-NEXT: ret i1 [[RES]] +; + %a.not = xor i1 %a, true + %res = select i1 %a, i1 %b, i1 %a.not + ret i1 %res +} + +define i1 @logical_or_not_cond_reuse(i1 %a, i1 %b) { +; CHECK-LABEL: @logical_or_not_cond_reuse( +; CHECK-NEXT: [[A_NOT:%.*]] = xor i1 [[A:%.*]], true +; CHECK-NEXT: [[RES:%.*]] = select i1 [[A_NOT]], i1 [[B:%.*]], i1 false +; CHECK-NEXT: ret i1 [[RES]] +; + %a.not = xor i1 %a, true + %res = select i1 %a, i1 %a.not, i1 %b + ret i1 %res +}