Index: llvm/include/llvm/Analysis/ValueTracking.h
===================================================================
--- llvm/include/llvm/Analysis/ValueTracking.h
+++ llvm/include/llvm/Analysis/ValueTracking.h
@@ -101,6 +101,13 @@
                       const Instruction *CxtI = nullptr,
                       const DominatorTree *DT = nullptr);
 
+  /// Return true if the two given values are known to be negations of each
+  /// other. Currently recognizes the pairs:
+  ///   X = sub (0, Y), or Y = sub (0, X)
+  ///   X = sub (A, B) and Y = sub (B, A)
+  /// Wrap flags (nsw/nuw) are not inspected.
+  bool isKnownNegation(const Value *X, const Value *Y);
+
   /// Returns true if the give value is known to be non-negative.
   bool isKnownNonNegative(const Value *V, const DataLayout &DL,
                           unsigned Depth = 0,
Index: llvm/lib/Analysis/ValueTracking.cpp
===================================================================
--- llvm/lib/Analysis/ValueTracking.cpp
+++ llvm/lib/Analysis/ValueTracking.cpp
@@ -4511,6 +4511,26 @@
   return {SPF_UNKNOWN, SPNB_NA, false};
 }
 
+bool llvm::isKnownNegation(const Value *X, const Value *Y) {
+  assert(X && Y && "Invalid operand");
+
+  // X = sub (0, Y)
+  if (match(X, m_Neg(m_Specific(Y))))
+    return true;
+
+  // Y = sub (0, X)
+  if (match(Y, m_Neg(m_Specific(X))))
+    return true;
+
+  // X = sub (A, B), Y = sub (B, A)
+  Value *A, *B;
+  if (match(X, m_Sub(m_Value(A), m_Value(B))) &&
+      match(Y, m_Sub(m_Specific(B), m_Specific(A))))
+    return true;
+
+  return false;
+}
+
 static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred,
                                               FastMathFlags FMF,
                                               Value *CmpLHS, Value *CmpRHS,
@@ -4614,30 +4634,40 @@
   // match against either LHS or sext(LHS).
   auto MaybeSExtLHS = m_CombineOr(m_Specific(CmpLHS),
                                   m_SExt(m_Specific(CmpLHS)));
-  if ((match(TrueVal, MaybeSExtLHS) &&
-       match(FalseVal, m_Neg(m_Specific(TrueVal)))) ||
-      (match(FalseVal, MaybeSExtLHS) &&
-       match(TrueVal, m_Neg(m_Specific(FalseVal))))) {
+  if ((match(TrueVal, MaybeSExtLHS) || match(FalseVal, MaybeSExtLHS)) &&
+      isKnownNegation(TrueVal, FalseVal)) {
+    // The compare may test the negated select operand instead of the plain
+    // one, e.g. (-X >s -1) ? -X : X. LHS is then still the plain operand X,
+    // and the ABS/NABS choice made below is inverted.
+    bool CmpUsesNegatedOp = match(CmpLHS, m_Neg(m_Specific(TrueVal))) ||
+                            match(CmpLHS, m_Neg(m_Specific(FalseVal)));
+
     // Set LHS and RHS so that RHS is the negated operand of the select
-    if (match(TrueVal, MaybeSExtLHS)) {
+    if (match(TrueVal, MaybeSExtLHS) != CmpUsesNegatedOp) {
       LHS = TrueVal;
       RHS = FalseVal;
     } else {
       LHS = FalseVal;
       RHS = TrueVal;
     }
 
     // (X >s 0) ? X : -X or (X >s -1) ? X : -X --> ABS(X)
     // (X >s 0) ? -X : X or (X >s -1) ? -X : X --> NABS(X)
+    // (-X >s 0) ? -X : X or (-X >s -1) ? -X : X --> ABS(X)
+    // (-X >s 0) ? X : -X or (-X >s -1) ? X : -X --> NABS(X)
     if (Pred == ICmpInst::ICMP_SGT &&
         match(CmpRHS, m_CombineOr(m_ZeroInt(), m_AllOnes())))
-      return {(LHS == TrueVal) ? SPF_ABS : SPF_NABS, SPNB_NA, false};
+      return {((LHS == TrueVal) != CmpUsesNegatedOp) ? SPF_ABS : SPF_NABS,
+              SPNB_NA, false};
 
     // (X <s 0) ? -X : X or (X <s 1) ? -X : X --> ABS(X)
     // (X <s 0) ? X : -X or (X <s 1) ? X : -X --> NABS(X)
+    // (-X <s 0) ? X : -X or (-X <s 1) ? X : -X --> ABS(X)
+    // (-X <s 0) ? -X : X or (-X <s 1) ? -X : X --> NABS(X)
     if (Pred == ICmpInst::ICMP_SLT &&
         match(CmpRHS, m_CombineOr(m_ZeroInt(), m_One())))
-      return {(LHS == FalseVal) ? SPF_ABS : SPF_NABS, SPNB_NA, false};
+      return {((LHS == FalseVal) != CmpUsesNegatedOp) ? SPF_ABS : SPF_NABS,
+              SPNB_NA, false};
   }
 
   if (CmpInst::isIntPredicate(Pred))
Index: llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -811,28 +811,64 @@
       SPF != SelectPatternFlavor::SPF_NABS)
     return nullptr;
 
+  Value *TVal = Sel.getTrueValue();
+  Value *FVal = Sel.getFalseValue();
+  assert(isKnownNegation(TVal, FVal) &&
+         "Unexpected result from matchSelectPattern");
+
+  // The compare may test the negated select operand (e.g. -X >s -1) instead
+  // of LHS. RHS cannot be used to detect that because it may be a
+  // non-canonical negation such as 'sub B, A', so look at the compare.
+  bool CmpUsesNegatedOp = match(Cmp.getOperand(0), m_Neg(m_Specific(TVal))) ||
+                          match(Cmp.getOperand(0), m_Neg(m_Specific(FVal)));
+
+  bool CmpCanonicalized = !CmpUsesNegatedOp &&
+                          match(Cmp.getOperand(1), m_ZeroInt()) &&
+                          Cmp.getPredicate() == ICmpInst::ICMP_SLT;
+  bool RHSCanonicalized = match(RHS, m_Neg(m_Specific(LHS)));
+
   // Is this already canonical?
-  if (match(Cmp.getOperand(1), m_ZeroInt()) &&
-      Cmp.getPredicate() == ICmpInst::ICMP_SLT)
+  if (CmpCanonicalized && RHSCanonicalized)
     return nullptr;
 
-  // Create the canonical compare.
-  Cmp.setPredicate(ICmpInst::ICMP_SLT);
-  Cmp.setOperand(1, ConstantInt::getNullValue(Cmp.getOperand(0)->getType()));
+  // A non-canonical negation (sub B, A) is replaced below by a new
+  // 'sub 0, LHS'. Only do that when the select is its sole user, so that the
+  // old instruction dies and the instruction count does not grow.
+  if (!RHSCanonicalized && !RHS->hasOneUse())
+    return nullptr;
+
+  // Create the canonical compare: icmp slt LHS, 0.
+  if (!CmpCanonicalized) {
+    Cmp.setPredicate(ICmpInst::ICMP_SLT);
+    Cmp.setOperand(1, ConstantInt::getNullValue(Cmp.getOperand(0)->getType()));
+    if (CmpUsesNegatedOp)
+      Cmp.setOperand(0, LHS);
+  }
+
+  // Create the canonical RHS: sub 0, LHS. Insert a new instruction before the
+  // select instead of rewriting RHS in place: LHS need not dominate RHS, and
+  // nsw/nuw flags on 'sub B, A' do not necessarily hold for 'sub 0, LHS'.
+  if (!RHSCanonicalized) {
+    Value *NegLHS = BinaryOperator::CreateNeg(LHS, "", &Sel);
+    NegLHS->takeName(RHS);
+    if (RHS == TVal) {
+      Sel.setOperand(1, NegLHS);
+      TVal = NegLHS;
+    } else {
+      Sel.setOperand(2, NegLHS);
+      FVal = NegLHS;
+    }
+  }
 
   // If the select operands do not change, we're done.
-  Value *TVal = Sel.getTrueValue();
-  Value *FVal = Sel.getFalseValue();
   if (SPF == SelectPatternFlavor::SPF_NABS) {
-    if (TVal == LHS && match(FVal, m_Neg(m_Specific(TVal))))
+    if (TVal == LHS)
       return &Sel;
-    assert(FVal == LHS && match(TVal, m_Neg(m_Specific(FVal))) &&
-           "Unexpected results from matchSelectPattern");
+    assert(FVal == LHS && "Unexpected results from matchSelectPattern");
   } else {
-    if (FVal == LHS && match(TVal, m_Neg(m_Specific(FVal))))
+    if (FVal == LHS)
       return &Sel;
-    assert(TVal == LHS && match(FVal, m_Neg(m_Specific(TVal))) &&
-           "Unexpected results from matchSelectPattern");
+    assert(TVal == LHS && "Unexpected results from matchSelectPattern");
   }
 
   // We are swapping the select operands, so swap the metadata too.
Index: llvm/test/Transforms/InstCombine/abs-1.ll
===================================================================
--- llvm/test/Transforms/InstCombine/abs-1.ll
+++ llvm/test/Transforms/InstCombine/abs-1.ll
@@ -131,12 +131,30 @@
   ret i32 %abs
 }
 
+; The non-canonical negation (sub B, A) is defined before the abs operand
+; (sub A, B) and carries nsw. The canonical 'sub 0, X' must be created after
+; X, and must not inherit nsw.
+define i32 @abs_negation_defined_first(i32 %a, i32 %b) {
+; CHECK-LABEL: @abs_negation_defined_first(
+; CHECK-NEXT:    [[X:%.*]] = sub i32 [[A:%.*]], [[B:%.*]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[X]], 0
+; CHECK-NEXT:    [[NEG:%.*]] = sub i32 0, [[X]]
+; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[CMP]], i32 [[NEG]], i32 [[X]]
+; CHECK-NEXT:    ret i32 [[ABS]]
+;
+  %neg = sub nsw i32 %b, %a
+  %x = sub i32 %a, %b
+  %cmp = icmp sgt i32 %x, -1
+  %abs = select i1 %cmp, i32 %x, i32 %neg
+  ret i32 %abs
+}
+
 define i32 @abs_canonical_6(i32 %a, i32 %b) {
 ; CHECK-LABEL: @abs_canonical_6(
 ; CHECK-NEXT:    [[TMP1:%.*]] = sub i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[TMP1]], -1
-; CHECK-NEXT:    [[TMP2:%.*]] = sub i32 [[B]], [[A]]
-; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[CMP]], i32 [[TMP1]], i32 [[TMP2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP2:%.*]] = sub i32 0, [[TMP1]]
+; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[CMP]], i32 [[TMP2]], i32 [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[ABS]]
 ;
   %tmp1 = sub i32 %a, %b
@@ -149,9 +167,9 @@
 define <2 x i8> @abs_canonical_7(<2 x i8> %a, <2 x i8 > %b) {
 ; CHECK-LABEL: @abs_canonical_7(
 ; CHECK-NEXT:    [[TMP1:%.*]] = sub <2 x i8> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt <2 x i8> [[TMP1]], <i8 -1, i8 -1>
-; CHECK-NEXT:    [[TMP2:%.*]] = sub <2 x i8> [[B]], [[A]]
-; CHECK-NEXT:    [[ABS:%.*]] = select <2 x i1> [[CMP]], <2 x i8> [[TMP1]], <2 x i8> [[TMP2]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <2 x i8> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = sub <2 x i8> zeroinitializer, [[TMP1]]
+; CHECK-NEXT:    [[ABS:%.*]] = select <2 x i1> [[CMP]], <2 x i8> [[TMP2]], <2 x i8> [[TMP1]]
 ; CHECK-NEXT:    ret <2 x i8> [[ABS]]
 ;
 
@@ -165,8 +183,8 @@
 define i32 @abs_canonical_8(i32 %a) {
 ; CHECK-LABEL: @abs_canonical_8(
 ; CHECK-NEXT:    [[TMP:%.*]] = sub i32 0, [[A:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[TMP]], 0
-; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[CMP]], i32 [[A]], i32 [[TMP]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[A]], 0
+; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[CMP]], i32 [[TMP]], i32 [[A]]
 ; CHECK-NEXT:    ret i32 [[ABS]]
 ;
   %tmp = sub i32 0, %a
@@ -266,9 +284,9 @@
 define i32 @nabs_canonical_6(i32 %a, i32 %b) {
 ; CHECK-LABEL: @nabs_canonical_6(
 ; CHECK-NEXT:    [[TMP1:%.*]] = sub i32 [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt i32 [[TMP1]], -1
-; CHECK-NEXT:    [[TMP2:%.*]] = sub i32 [[B]], [[A]]
-; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[CMP]], i32 [[TMP2]], i32 [[TMP1]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[TMP1]], 0
+; CHECK-NEXT:    [[TMP2:%.*]] = sub i32 0, [[TMP1]]
+; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[CMP]], i32 [[TMP1]], i32 [[TMP2]]
 ; CHECK-NEXT:    ret i32 [[ABS]]
 ;
   %tmp1 = sub i32 %a, %b
@@ -281,9 +299,9 @@
 define <2 x i8> @nabs_canonical_7(<2 x i8> %a, <2 x i8 > %b) {
 ; CHECK-LABEL: @nabs_canonical_7(
 ; CHECK-NEXT:    [[TMP1:%.*]] = sub <2 x i8> [[A:%.*]], [[B:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp sgt <2 x i8> [[TMP1]], <i8 -1, i8 -1>
-; CHECK-NEXT:    [[TMP2:%.*]] = sub <2 x i8> [[B]], [[A]]
-; CHECK-NEXT:    [[ABS:%.*]] = select <2 x i1> [[CMP]], <2 x i8> [[TMP2]], <2 x i8> [[TMP1]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt <2 x i8> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = sub <2 x i8> zeroinitializer, [[TMP1]]
+; CHECK-NEXT:    [[ABS:%.*]] = select <2 x i1> [[CMP]], <2 x i8> [[TMP1]], <2 x i8> [[TMP2]]
 ; CHECK-NEXT:    ret <2 x i8> [[ABS]]
 ;
   %tmp1 = sub <2 x i8> %a, %b
@@ -296,8 +314,8 @@
 define i32 @nabs_canonical_8(i32 %a) {
 ; CHECK-LABEL: @nabs_canonical_8(
 ; CHECK-NEXT:    [[TMP:%.*]] = sub i32 0, [[A:%.*]]
-; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[TMP]], 0
-; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[CMP]], i32 [[TMP]], i32 [[A]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[A]], 0
+; CHECK-NEXT:    [[ABS:%.*]] = select i1 [[CMP]], i32 [[A]], i32 [[TMP]]
 ; CHECK-NEXT:    ret i32 [[ABS]]
 ;
   %tmp = sub i32 0, %a