diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -2521,26 +2521,31 @@ return isKnownNonZero(V, DemandedElts, Depth, Q); } -/// If the pair of operators are the same invertible function of a single -/// operand return the index of that operand. Otherwise, return None. An -/// invertible function is one that is 1-to-1 and maps every input value -/// to exactly one output value. This is equivalent to saying that Op1 -/// and Op2 are equal exactly when the specified pair of operands are equal, -/// (except that Op1 and Op2 may be poison more often.) -static Optional getInvertibleOperand(const Operator *Op1, - const Operator *Op2) { +/// If the pair of operators are the same invertible function, return the +/// operands of the function corresponding to each input. Otherwise, +/// return None. An invertible function is one that is 1-to-1 and maps +/// every input value to exactly one output value. This is equivalent to +/// saying that Op1 and Op2 are equal exactly when the specified pair of +/// operands are equal, (except that Op1 and Op2 may be poison more often.) 
+static Optional> +getInvertibleOperands(const Operator *Op1, + const Operator *Op2) { if (Op1->getOpcode() != Op2->getOpcode()) return None; + auto getOperands = [&](unsigned OpNum) -> auto { + return std::make_pair(Op1->getOperand(OpNum), Op2->getOperand(OpNum)); + }; + switch (Op1->getOpcode()) { default: break; case Instruction::Add: case Instruction::Sub: if (Op1->getOperand(0) == Op2->getOperand(0)) - return 1; + return getOperands(1); if (Op1->getOperand(1) == Op2->getOperand(1)) - return 0; + return getOperands(0); break; case Instruction::Mul: { // invertible if A * B == (A * B) mod 2^N where A, and B are integers @@ -2556,7 +2561,7 @@ if (Op1->getOperand(1) == Op2->getOperand(1) && isa(Op1->getOperand(1)) && !cast(Op1->getOperand(1))->isZero()) - return 0; + return getOperands(0); break; } case Instruction::Shl: { @@ -2569,7 +2574,7 @@ break; if (Op1->getOperand(1) == Op2->getOperand(1)) - return 0; + return getOperands(0); break; } case Instruction::AShr: @@ -2580,13 +2585,13 @@ break; if (Op1->getOperand(1) == Op2->getOperand(1)) - return 0; + return getOperands(0); break; } case Instruction::SExt: case Instruction::ZExt: if (Op1->getOperand(0)->getType() == Op2->getOperand(0)->getType()) - return 0; + return getOperands(0); break; case Instruction::PHI: { const PHINode *PN1 = cast(Op1); @@ -2604,19 +2609,20 @@ !matchSimpleRecurrence(PN2, BO2, Start2, Step2)) break; - Optional Idx = getInvertibleOperand(cast(BO1), - cast(BO2)); - if (!Idx || *Idx != 0) - break; - if (BO1->getOperand(*Idx) != PN1 || BO2->getOperand(*Idx) != PN2) + auto Values = getInvertibleOperands(cast(BO1), + cast(BO2)); + if (!Values) + break; + + // We have to be careful of mutually defined recurrences here. Ex: + // * X_i = X_(i-1) OP Y_(i-1), and Y_i = X_(i-1) OP V + // * X_i = Y_i = X_(i-1) OP Y_(i-1) + // The invertibility of these is complicated, and not worth reasoning + // about (yet?). 
+ if (Values->first != PN1 || Values->second != PN2) break; - // Phi operands might not be in the same order. TODO: generalize - // interface to return pair of operands. - if (PN1->getOperand(0) == BO1 && PN2->getOperand(0) == BO2) - return 1; - if (PN1->getOperand(1) == BO1 && PN2->getOperand(1) == BO2) - return 0; + return std::make_pair(Start1, Start2); } } return None; @@ -2713,11 +2719,9 @@ auto *O1 = dyn_cast(V1); auto *O2 = dyn_cast(V2); if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) { - if (Optional Opt = getInvertibleOperand(O1, O2)) { - unsigned Idx = *Opt; - return isKnownNonEqual(O1->getOperand(Idx), O2->getOperand(Idx), - Depth + 1, Q); - } + if (auto Values = getInvertibleOperands(O1, O2)) + return isKnownNonEqual(Values->first, Values->second, Depth + 1, Q); + if (const PHINode *PN1 = dyn_cast(V1)) { const PHINode *PN2 = cast(V2); // FIXME: This is missing a generalization to handle the case where one is diff --git a/llvm/test/Analysis/ValueTracking/known-non-equal.ll b/llvm/test/Analysis/ValueTracking/known-non-equal.ll --- a/llvm/test/Analysis/ValueTracking/known-non-equal.ll +++ b/llvm/test/Analysis/ValueTracking/known-non-equal.ll @@ -736,8 +736,7 @@ ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10 ; CHECK-NEXT: br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]] -; CHECK-NEXT: ret i1 [[RES]] +; CHECK-NEXT: ret i1 false ; entry: %B = add i8 %A, 1 @@ -808,8 +807,7 @@ ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10 ; CHECK-NEXT: br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]] -; CHECK-NEXT: ret i1 [[RES]] +; CHECK-NEXT: ret i1 false ; entry: %B = add i8 %A, 1 @@ -843,8 +841,7 @@ ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10 ; CHECK-NEXT: br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]] -; CHECK-NEXT: 
ret i1 [[RES]] +; CHECK-NEXT: ret i1 false ; entry: %B = add i8 %A, 1 @@ -979,8 +976,7 @@ ; CHECK-NEXT: [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10 ; CHECK-NEXT: br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]] ; CHECK: exit: -; CHECK-NEXT: [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]] -; CHECK-NEXT: ret i1 [[RES]] +; CHECK-NEXT: ret i1 false ; entry: %B = add i8 %A, 1