Index: llvm/lib/Analysis/ValueTracking.cpp
===================================================================
--- llvm/lib/Analysis/ValueTracking.cpp
+++ llvm/lib/Analysis/ValueTracking.cpp
@@ -2513,26 +2513,32 @@
   return isKnownNonZero(V, DemandedElts, Depth, Q);
 }
 
-/// If the pair of operators are the same invertible function of a single
-/// operand return the index of that operand. Otherwise, return None. An
-/// invertible function is one that is 1-to-1 and maps every input value
-/// to exactly one output value. This is equivalent to saying that Op1
-/// and Op2 are equal exactly when the specified pair of operands are equal,
-/// (except that Op1 and Op2 may be poison more often.)
-static Optional<unsigned> getInvertibleOperand(const Operator *Op1,
-                                               const Operator *Op2) {
+/// If the pair of operators are the same invertible function, return the
+/// operands of the function corresponding to each input (first from Op1,
+/// second from Op2); otherwise return None. An invertible function is 1-to-1
+/// and maps every input value to exactly one output value, so Op1 and Op2
+/// are equal exactly when the returned pair of operands are equal (except
+/// that Op1 and Op2 may be poison more often).
+static Optional<std::pair<Value *, Value *>>
+getInvertibleOperand(const Operator *Op1,
+                     const Operator *Op2) {
   if (Op1->getOpcode() != Op2->getOpcode())
     return None;
 
+  // Pair operand OpNum of Op1 with operand OpNum of Op2.
+  auto getOperands = [&](unsigned OpNum) {
+    return std::make_pair(Op1->getOperand(OpNum), Op2->getOperand(OpNum));
+  };
+
   switch (Op1->getOpcode()) {
   default:
     break;
   case Instruction::Add:
   case Instruction::Sub:
     if (Op1->getOperand(0) == Op2->getOperand(0))
-      return 1;
+      return getOperands(1);
     if (Op1->getOperand(1) == Op2->getOperand(1))
-      return 0;
+      return getOperands(0);
     break;
   case Instruction::Mul: {
     // invertible if A * B == (A * B) mod 2^N where A, and B are integers
@@ -2548,7 +2554,7 @@
     if (Op1->getOperand(1) == Op2->getOperand(1) &&
         isa<ConstantInt>(Op1->getOperand(1)) &&
         !cast<ConstantInt>(Op1->getOperand(1))->isZero())
-      return 0;
+      return getOperands(0);
     break;
   }
   case Instruction::Shl: {
@@ -2561,7 +2567,7 @@
       break;
 
     if (Op1->getOperand(1) == Op2->getOperand(1))
-      return 0;
+      return getOperands(0);
     break;
   }
   case Instruction::AShr:
@@ -2572,13 +2578,13 @@
       break;
 
     if (Op1->getOperand(1) == Op2->getOperand(1))
-      return 0;
+      return getOperands(0);
     break;
   }
   case Instruction::SExt:
   case Instruction::ZExt:
     if (Op1->getOperand(0)->getType() == Op2->getOperand(0)->getType())
-      return 0;
+      return getOperands(0);
     break;
   case Instruction::PHI: {
     const PHINode *PN1 = cast<PHINode>(Op1);
@@ -2596,18 +2602,21 @@
         !matchSimpleRecurrence(PN2, BO2, Start2, Step2))
       break;
 
-    Optional<unsigned> Idx = getInvertibleOperand(cast<Operator>(BO1),
-                                                  cast<Operator>(BO2));
-    if (!Idx || *Idx != 0)
+    auto Values = getInvertibleOperand(cast<Operator>(BO1),
+                                       cast<Operator>(BO2));
+    if (!Values)
       break;
-    assert(BO1->getOperand(*Idx) == PN1 && BO2->getOperand(*Idx) == PN2);
-
-    // Phi operands might not be in the same order. TODO: generalize
-    // interface to return pair of operands.
-    if (PN1->getOperand(0) == BO1 && PN2->getOperand(0) == BO2)
-      return 1;
-    if (PN1->getOperand(1) == BO1 && PN2->getOperand(1) == BO2)
-      return 0;
+
+    // Only conclude that the recurrences are invertible functions of their
+    // start values when each step is invertible in its own phi. This does
+    // not hold for mutually defined recurrences: with BO1 = add PN1, PN2 and
+    // BO2 = add PN1, PN2, the shared operand is PN1, the returned pair is
+    // (PN2, PN2), and both phis equal Start1 + Start2 after one iteration
+    // even when Start1 != Start2. So this must be checked, not asserted.
+    if (Values->first != PN1 || Values->second != PN2)
+      break;
+
+    return std::make_pair(Start1, Start2);
   }
   }
   return None;
@@ -2704,11 +2713,9 @@
   auto *O1 = dyn_cast<Operator>(V1);
   auto *O2 = dyn_cast<Operator>(V2);
   if (O1 && O2 && O1->getOpcode() == O2->getOpcode()) {
-    if (Optional<unsigned> Opt = getInvertibleOperand(O1, O2)) {
-      unsigned Idx = *Opt;
-      return isKnownNonEqual(O1->getOperand(Idx), O2->getOperand(Idx),
-                             Depth + 1, Q);
-    }
+    if (auto Values = getInvertibleOperand(O1, O2))
+      return isKnownNonEqual(Values->first, Values->second, Depth + 1, Q);
+
     if (const PHINode *PN1 = dyn_cast<PHINode>(V1)) {
       const PHINode *PN2 = cast<PHINode>(V2);
       // FIXME: This is missing a generalization to handle the case where one is
Index: llvm/test/Analysis/ValueTracking/known-non-equal.ll
===================================================================
--- llvm/test/Analysis/ValueTracking/known-non-equal.ll
+++ llvm/test/Analysis/ValueTracking/known-non-equal.ll
@@ -738,8 +738,7 @@
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %B = add i8 %A, 1
@@ -810,8 +809,7 @@
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %B = add i8 %A, 1
@@ -845,8 +843,7 @@
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i64 [[IV_NEXT]], 10
 ; CHECK-NEXT:    br i1 [[CMP]], label [[LOOP]], label [[EXIT:%.*]]
 ; CHECK:       exit:
-; CHECK-NEXT:    [[RES:%.*]] = icmp eq i8 [[A_IV]], [[B_IV]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    ret i1 false
 ;
 entry:
   %B = add i8 %A, 1