diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -599,6 +599,15 @@
     return Result;
   }
 
+  /// Determine the pattern that a select with the given compare as its
+  /// predicate and given values as its true/false operands would match.
+  /// As with matchSelectPattern, the LHS/RHS out parameters are only
+  /// meaningful if a pattern is matched (the returned flavor is not
+  /// SPF_UNKNOWN); callers must not rely on their contents otherwise.
+  SelectPatternResult matchDecomposedSelectPattern(
+      CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS,
+      Instruction::CastOps *CastOp = nullptr, unsigned Depth = 0);
+
   /// Return the canonical comparison predicate for the specified
   /// minimum/maximum flavor.
   CmpInst::Predicate getMinMaxPred(SelectPatternFlavor SPF,
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -5067,11 +5067,19 @@
   CmpInst *CmpI = dyn_cast<CmpInst>(SI->getCondition());
   if (!CmpI) return {SPF_UNKNOWN, SPNB_NA, false};
 
+  Value *TrueVal = SI->getTrueValue();
+  Value *FalseVal = SI->getFalseValue();
+
+  return llvm::matchDecomposedSelectPattern(CmpI, TrueVal, FalseVal, LHS, RHS,
+                                            CastOp, Depth);
+}
+
+SelectPatternResult llvm::matchDecomposedSelectPattern(
+    CmpInst *CmpI, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS,
+    Instruction::CastOps *CastOp, unsigned Depth) {
   CmpInst::Predicate Pred = CmpI->getPredicate();
   Value *CmpLHS = CmpI->getOperand(0);
   Value *CmpRHS = CmpI->getOperand(1);
-  Value *TrueVal = SI->getTrueValue();
-  Value *FalseVal = SI->getFalseValue();
   FastMathFlags FMF;
   if (isa<FPMathOperator>(CmpI))
     FMF = CmpI->getFastMathFlags();
diff --git a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
--- a/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
+++ b/llvm/lib/Transforms/Scalar/EarlyCSE.cpp
@@ -80,6 +80,11 @@
     cl::desc("Enable imprecision in EarlyCSE in pathological cases, in exchange "
              "for faster compile. Caps the MemorySSA clobbering calls."));
 
+static cl::opt<bool> EarlyCSEDebugHash(
+    "earlycse-debug-hash", cl::init(false), cl::Hidden,
+    cl::desc("Perform extra assertion checking to verify that SimpleValue's hash "
+             "function is well-behaved w.r.t. its isEqual predicate"));
+
 //===----------------------------------------------------------------------===//
 // SimpleValue
 //===----------------------------------------------------------------------===//
@@ -130,22 +135,54 @@
 
 } // end namespace llvm
 
 /// Match a 'select' including an optional 'not' of the condition.
-static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond,
-                                           Value *&T, Value *&F) {
-  if (match(V, m_Select(m_Value(Cond), m_Value(T), m_Value(F)))) {
-    // Look through a 'not' of the condition operand by swapping true/false.
-    Value *CondNot;
-    if (match(Cond, m_Not(m_Value(CondNot)))) {
-      Cond = CondNot;
-      std::swap(T, F);
-    }
-    return true;
+///
+/// Returns false if V is not a select. Otherwise sets Cond to the condition
+/// (looking through one 'not' by swapping A/B), sets A/B to the values chosen
+/// when Cond is true/false, and returns true. If Cond is an integer compare
+/// and the select is a recognized min/max/abs idiom, Flavor is set to that
+/// flavor and A/B are replaced by the operands reported by
+/// matchDecomposedSelectPattern; otherwise Flavor is SPF_UNKNOWN and A/B are
+/// left as the select operands.
+static bool matchSelectWithOptionalNotCond(Value *V, Value *&Cond, Value *&A,
+                                           Value *&B,
+                                           SelectPatternFlavor &Flavor) {
+  // Return false if V is not even a select.
+  if (!match(V, m_Select(m_Value(Cond), m_Value(A), m_Value(B))))
+    return false;
+
+  // Look through a 'not' of the condition operand by swapping A/B.
+  Value *CondNot;
+  if (match(Cond, m_Not(m_Value(CondNot)))) {
+    Cond = CondNot;
+    std::swap(A, B);
   }
-  return false;
+
+  // Set flavor if we find a match, or leave it unknown otherwise; in either
+  // case, return true to indicate that this is a select we can process.
+  //
+  // Do not pass A/B directly as the matcher's LHS/RHS outputs: those are only
+  // meaningful when a pattern is recognized and need not hold the select
+  // operands when the result is SPF_UNKNOWN, but the general-select handling
+  // in getHashValueImpl/isEqualImpl hashes and compares A/B.
+  // TODO: We should also detect FP min/max; only integer compares for now.
+  Flavor = SPF_UNKNOWN;
+  if (auto *CmpI = dyn_cast<ICmpInst>(Cond)) {
+    Value *LHS = nullptr, *RHS = nullptr;
+    SelectPatternFlavor SPF =
+        matchDecomposedSelectPattern(CmpI, A, B, LHS, RHS).Flavor;
+    if (SPF == SPF_SMIN || SPF == SPF_SMAX || SPF == SPF_UMIN ||
+        SPF == SPF_UMAX || SPF == SPF_ABS || SPF == SPF_NABS) {
+      Flavor = SPF;
+      A = LHS;
+      B = RHS;
+    }
+  }
+
+  return true;
 }
 
-unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) {
+static unsigned getHashValueImpl(SimpleValue Val) {
   Instruction *Inst = Val.Inst;
   // Hash in all of the operands as pointers.
   if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst)) {
@@ -168,40 +205,41 @@
     return hash_combine(Inst->getOpcode(), Pred, LHS, RHS);
   }
 
-  // Hash min/max/abs (cmp + select) to allow for commuted operands.
-  // Min/max may also have non-canonical compare predicate (eg, the compare for
-  // smin may use 'sgt' rather than 'slt'), and non-canonical operands in the
-  // compare.
-  Value *A, *B;
-  SelectPatternFlavor SPF = matchSelectPattern(Inst, A, B).Flavor;
-  // TODO: We should also detect FP min/max.
-  if (SPF == SPF_SMIN || SPF == SPF_SMAX ||
-      SPF == SPF_UMIN || SPF == SPF_UMAX) {
-    if (A > B)
-      std::swap(A, B);
-    return hash_combine(Inst->getOpcode(), SPF, A, B);
-  }
-  if (SPF == SPF_ABS || SPF == SPF_NABS) {
-    // ABS/NABS always puts the input in A and its negation in B.
-    return hash_combine(Inst->getOpcode(), SPF, A, B);
-  }
-
   // Hash general selects to allow matching commuted true/false operands.
-  Value *Cond, *TVal, *FVal;
-  if (matchSelectWithOptionalNotCond(Inst, Cond, TVal, FVal)) {
+  SelectPatternFlavor SPF;
+  Value *Cond, *A, *B;
+  if (matchSelectWithOptionalNotCond(Inst, Cond, A, B, SPF)) {
+    // Hash min/max/abs (cmp + select) to allow for commuted operands.
+    // Min/max may also have non-canonical compare predicate (eg, the compare for
+    // smin may use 'sgt' rather than 'slt'), and non-canonical operands in the
+    // compare.
+    // TODO: We should also detect FP min/max.
+    if (SPF == SPF_SMIN || SPF == SPF_SMAX ||
+        SPF == SPF_UMIN || SPF == SPF_UMAX) {
+      if (A > B)
+        std::swap(A, B);
+      return hash_combine(Inst->getOpcode(), SPF, A, B);
+    }
+    if (SPF == SPF_ABS || SPF == SPF_NABS) {
+      // ABS/NABS always puts the input in A and its negation in B.
+      return hash_combine(Inst->getOpcode(), SPF, A, B);
+    }
+
+    // Hash general selects to allow matching commuted true/false operands.
+
     // If we do not have a compare as the condition, just hash in the condition.
     CmpInst::Predicate Pred;
     Value *X, *Y;
     if (!match(Cond, m_Cmp(Pred, m_Value(X), m_Value(Y))))
-      return hash_combine(Inst->getOpcode(), Cond, TVal, FVal);
+      return hash_combine(Inst->getOpcode(), Cond, A, B);
 
     // Similar to cmp normalization (above) - canonicalize the predicate value:
-    // select (icmp Pred, X, Y), T, F --> select (icmp InvPred, X, Y), F, T
+    // select (icmp Pred, X, Y), A, B --> select (icmp InvPred, X, Y), B, A
     if (CmpInst::getInversePredicate(Pred) < Pred) {
       Pred = CmpInst::getInversePredicate(Pred);
-      std::swap(TVal, FVal);
+      std::swap(A, B);
     }
-    return hash_combine(Inst->getOpcode(), Pred, X, Y, TVal, FVal);
+    return hash_combine(Inst->getOpcode(), Pred, X, Y, A, B);
   }
 
   if (CastInst *CI = dyn_cast<CastInst>(Inst))
@@ -227,7 +265,19 @@
       hash_combine_range(Inst->value_op_begin(), Inst->value_op_end()));
 }
 
-bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) {
+unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) {
+#ifndef NDEBUG
+  // If -earlycse-debug-hash was specified, return a constant -- this
+  // will force all hashing to collide, so we'll exhaustively search
+  // the table for a match, and the assertion in isEqual will fire if
+  // there's a bug causing equal keys to hash differently.
+  if (EarlyCSEDebugHash)
+    return 0;
+#endif
+  return getHashValueImpl(Val);
+}
+
+static bool isEqualImpl(SimpleValue LHS, SimpleValue RHS) {
   Instruction *LHSI = LHS.Inst, *RHSI = RHS.Inst;
 
   if (LHS.isSentinel() || RHS.isSentinel())
@@ -263,39 +313,46 @@
 
   // Min/max/abs can occur with commuted operands, non-canonical predicates,
   // and/or non-canonical operands.
-  Value *LHSA, *LHSB;
-  SelectPatternFlavor LSPF = matchSelectPattern(LHSI, LHSA, LHSB).Flavor;
-  // TODO: We should also detect FP min/max.
-  if (LSPF == SPF_SMIN || LSPF == SPF_SMAX ||
-      LSPF == SPF_UMIN || LSPF == SPF_UMAX ||
-      LSPF == SPF_ABS || LSPF == SPF_NABS) {
-    Value *RHSA, *RHSB;
-    SelectPatternFlavor RSPF = matchSelectPattern(RHSI, RHSA, RHSB).Flavor;
+  // Selects can be non-trivially equivalent via inverted conditions and swaps.
+  SelectPatternFlavor LSPF, RSPF;
+  Value *CondL, *CondR, *LHSA, *RHSA, *LHSB, *RHSB;
+  if (matchSelectWithOptionalNotCond(LHSI, CondL, LHSA, LHSB, LSPF) &&
+      matchSelectWithOptionalNotCond(RHSI, CondR, RHSA, RHSB, RSPF)) {
     if (LSPF == RSPF) {
-      // Abs results are placed in a defined order by matchSelectPattern.
-      if (LSPF == SPF_ABS || LSPF == SPF_NABS)
+      // TODO: We should also detect FP min/max.
+      if (LSPF == SPF_SMIN || LSPF == SPF_SMAX ||
+          LSPF == SPF_UMIN || LSPF == SPF_UMAX)
+        return ((LHSA == RHSA && LHSB == RHSB) ||
+                (LHSA == RHSB && LHSB == RHSA));
+
+      if (LSPF == SPF_ABS || LSPF == SPF_NABS) {
+        // Abs results are placed in a defined order by matchSelectPattern.
         return LHSA == RHSA && LHSB == RHSB;
-      return ((LHSA == RHSA && LHSB == RHSB) ||
-              (LHSA == RHSB && LHSB == RHSA));
-    }
-  }
+      }
 
-  // Selects can be non-trivially equivalent via inverted conditions and swaps.
-  Value *CondL, *CondR, *TrueL, *TrueR, *FalseL, *FalseR;
-  if (matchSelectWithOptionalNotCond(LHSI, CondL, TrueL, FalseL) &&
-      matchSelectWithOptionalNotCond(RHSI, CondR, TrueR, FalseR)) {
-    // select Cond, T, F <--> select not(Cond), F, T
-    if (CondL == CondR && TrueL == TrueR && FalseL == FalseR)
-      return true;
+      // select Cond, A, B <--> select not(Cond), B, A
+      if (CondL == CondR && LHSA == RHSA && LHSB == RHSB)
+        return true;
+    }
 
     // If the true/false operands are swapped and the conditions are compares
     // with inverted predicates, the selects are equal:
-    // select (icmp Pred, X, Y), T, F <--> select (icmp InvPred, X, Y), F, T
+    // select (icmp Pred, X, Y), A, B <--> select (icmp InvPred, X, Y), B, A
     //
-    // This also handles patterns with a double-negation because we looked
-    // through a 'not' in the matching function and swapped T/F:
-    // select (cmp Pred, X, Y), T, F <--> select (not (cmp InvPred, X, Y)), T, F
-    if (TrueL == FalseR && FalseL == TrueR) {
+    // This also handles patterns with a double-negation in the sense of not +
+    // inverse, because we looked through a 'not' in the matching function and
+    // swapped A/B:
+    // select (cmp Pred, X, Y), A, B <--> select (not (cmp InvPred, X, Y)), B, A
+    //
+    // This intentionally does NOT handle patterns with a double-negation in
+    // the sense of not + not, because doing so could result in values
+    // comparing as equal that hash differently in the min/max/abs cases like:
+    // select (cmp slt, X, Y), X, Y <--> select (not (not (cmp slt, X, Y))), X, Y
+    // ^ hashes as min                    ^ would not hash as min
+    // In the context of the EarlyCSE pass, however, such cases never reach
+    // this code, as we simplify the double-negation before hashing the second
+    // select (and so still succeed at CSEing them).
+    if (LHSA == RHSB && LHSB == RHSA) {
       CmpInst::Predicate PredL, PredR;
       Value *X, *Y;
       if (match(CondL, m_Cmp(PredL, m_Value(X), m_Value(Y))) &&
@@ -308,6 +365,15 @@
   return false;
 }
 
+bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) {
+  // These comparisons are nontrivial, so assert that equality implies
+  // hash equality (DenseMap demands this as an invariant).
+  bool Result = isEqualImpl(LHS, RHS);
+  assert(!Result || (LHS.isSentinel() && LHS.Inst == RHS.Inst) ||
+         getHashValueImpl(LHS) == getHashValueImpl(RHS));
+  return Result;
+}
+
 //===----------------------------------------------------------------------===//
 // CallValue
 //===----------------------------------------------------------------------===//
diff --git a/llvm/test/Transforms/EarlyCSE/commute.ll b/llvm/test/Transforms/EarlyCSE/commute.ll
--- a/llvm/test/Transforms/EarlyCSE/commute.ll
+++ b/llvm/test/Transforms/EarlyCSE/commute.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -S -early-cse | FileCheck %s
+; RUN: opt < %s -S -early-cse -earlycse-debug-hash | FileCheck %s
 ; RUN: opt < %s -S -basicaa -early-cse-memssa | FileCheck %s
 
 define void @test1(float %A, float %B, float* %PA, float* %PB) {
@@ -108,14 +108,13 @@
 }
 
 ; Min/max can also have an inverted predicate and select operands.
-; TODO: Ensure we always recognize this (currently depends on hash collision)
 define i1 @smin_inverted(i8 %a, i8 %b) {
 ; CHECK-LABEL: @smin_inverted(
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i8 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    [[CMP2:%.*]] = xor i1 [[CMP1]], true
 ; CHECK-NEXT:    [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[B]]
-; CHECK:         ret i1
+; CHECK-NEXT:    ret i1 true
 ;
   %cmp1 = icmp slt i8 %a, %b
   %cmp2 = xor i1 %cmp1, -1
@@ -155,13 +154,12 @@
   ret i8 %r
 }
 
-; TODO: Ensure we always recognize this (currently depends on hash collision)
 define i1 @smax_inverted(i8 %a, i8 %b) {
 ; CHECK-LABEL: @smax_inverted(
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp sgt i8 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    [[CMP2:%.*]] = xor i1 [[CMP1]], true
 ; CHECK-NEXT:    [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[B]]
-; CHECK:         ret i1
+; CHECK-NEXT:    ret i1 true
 ;
   %cmp1 = icmp sgt i8 %a, %b
   %cmp2 = xor i1 %cmp1, -1
@@ -203,13 +201,12 @@
   ret <2 x i8> %r
 }
 
-; TODO: Ensure we always recognize this (currently depends on hash collision)
 define i1 @umin_inverted(i8 %a, i8 %b) {
 ; CHECK-LABEL: @umin_inverted(
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp ult i8 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    [[CMP2:%.*]] = xor i1 [[CMP1]], true
 ; CHECK-NEXT:    [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[B]]
-; CHECK:         ret i1
+; CHECK-NEXT:    ret i1 true
 ;
   %cmp1 = icmp ult i8 %a, %b
   %cmp2 = xor i1 %cmp1, -1
@@ -250,13 +247,12 @@
   ret i8 %r
 }
 
-; TODO: Ensure we always recognize this (currently depends on hash collision)
 define i1 @umax_inverted(i8 %a, i8 %b) {
 ; CHECK-LABEL: @umax_inverted(
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp ugt i8 [[A:%.*]], [[B:%.*]]
 ; CHECK-NEXT:    [[CMP2:%.*]] = xor i1 [[CMP1]], true
 ; CHECK-NEXT:    [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[B]]
-; CHECK:         ret i1
+; CHECK-NEXT:    ret i1 true
 ;
   %cmp1 = icmp ugt i8 %a, %b
   %cmp2 = xor i1 %cmp1, -1
@@ -302,14 +298,13 @@
   ret i8 %r
 }
 
-; TODO: Ensure we always recognize this (currently depends on hash collision)
 define i8 @abs_inverted(i8 %a) {
 ; CHECK-LABEL: @abs_inverted(
 ; CHECK-NEXT:    [[NEG:%.*]] = sub i8 0, [[A:%.*]]
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp sgt i8 [[A]], 0
 ; CHECK-NEXT:    [[CMP2:%.*]] = xor i1 [[CMP1]], true
 ; CHECK-NEXT:    [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[NEG]]
-; CHECK:         ret i8
+; CHECK-NEXT:    ret i8 [[M1]]
 ;
   %neg = sub i8 0, %a
   %cmp1 = icmp sgt i8 %a, 0
@@ -337,14 +332,13 @@
   ret i8 %r
 }
 
-; TODO: Ensure we always recognize this (currently depends on hash collision)
 define i8 @nabs_inverted(i8 %a) {
 ; CHECK-LABEL: @nabs_inverted(
 ; CHECK-NEXT:    [[NEG:%.*]] = sub i8 0, [[A:%.*]]
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp slt i8 [[A]], 0
 ; CHECK-NEXT:    [[CMP2:%.*]] = xor i1 [[CMP1]], true
 ; CHECK-NEXT:    [[M1:%.*]] = select i1 [[CMP1]], i8 [[A]], i8 [[NEG]]
-; CHECK:         ret i8
+; CHECK-NEXT:    ret i8 0
 ;
   %neg = sub i8 0, %a
   %cmp1 = icmp slt i8 %a, 0
@@ -646,3 +640,56 @@
   %r = sub i32 %m2, %m1
   ret i32 %r
 }
+
+; This test is a reproducer for a bug involving inverted min/max selects
+; hashing differently but comparing as equal. It exhibits such a pair of
+; values, and we run this test with -earlycse-debug-hash which would catch
+; the disagreement and fail if it regressed. This test also includes a
+; negation of each negation to check for the same issue one level deeper.
+define void @not_not_min(i32* %px, i32* %py, i32* %pout) {
+; CHECK-LABEL: @not_not_min(
+; CHECK-NEXT:    [[X:%.*]] = load volatile i32, i32* [[PX:%.*]]
+; CHECK-NEXT:    [[Y:%.*]] = load volatile i32, i32* [[PY:%.*]]
+; CHECK-NEXT:    [[CMPA:%.*]] = icmp slt i32 [[X]], [[Y]]
+; CHECK-NEXT:    [[CMPB:%.*]] = xor i1 [[CMPA]], true
+; CHECK-NEXT:    [[RA:%.*]] = select i1 [[CMPA]], i32 [[X]], i32 [[Y]]
+; CHECK-NEXT:    store volatile i32 [[RA]], i32* [[POUT:%.*]]
+; CHECK-NEXT:    store volatile i32 [[RA]], i32* [[POUT]]
+; CHECK-NEXT:    store volatile i32 [[RA]], i32* [[POUT]]
+; CHECK-NEXT:    ret void
+;
+  %x = load volatile i32, i32* %px
+  %y = load volatile i32, i32* %py
+  %cmpa = icmp slt i32 %x, %y
+  %cmpb = xor i1 %cmpa, -1
+  %cmpc = xor i1 %cmpb, -1
+  %ra = select i1 %cmpa, i32 %x, i32 %y
+  %rb = select i1 %cmpb, i32 %y, i32 %x
+  %rc = select i1 %cmpc, i32 %x, i32 %y
+  store volatile i32 %ra, i32* %pout
+  store volatile i32 %rb, i32* %pout
+  store volatile i32 %rc, i32* %pout
+
+  ret void
+}
+
+; Selects that share an FP compare condition but have different true/false
+; operands must not be CSE'd. Matching the condition against min/max/abs
+; patterns must not replace the operands that general selects are hashed
+; and compared on.
+define void @select_fcmp_different_operands(float %x, float %y, i32 %p, i32 %q, i32* %pout) {
+; CHECK-LABEL: @select_fcmp_different_operands(
+; CHECK-NEXT:    [[CMP:%.*]] = fcmp olt float [[X:%.*]], [[Y:%.*]]
+; CHECK-NEXT:    [[R1:%.*]] = select i1 [[CMP]], i32 [[P:%.*]], i32 [[Q:%.*]]
+; CHECK-NEXT:    [[R2:%.*]] = select i1 [[CMP]], i32 [[Q]], i32 [[P]]
+; CHECK-NEXT:    store volatile i32 [[R1]], i32* [[POUT:%.*]]
+; CHECK-NEXT:    store volatile i32 [[R2]], i32* [[POUT]]
+; CHECK-NEXT:    ret void
+;
+  %cmp = fcmp olt float %x, %y
+  %r1 = select i1 %cmp, i32 %p, i32 %q
+  %r2 = select i1 %cmp, i32 %q, i32 %p
+  store volatile i32 %r1, i32* %pout
+  store volatile i32 %r2, i32* %pout
+  ret void
+}