diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -3436,7 +3436,6 @@
     return SDValue();
   }
 
-
   auto cancelDiamond = [&](SDValue A,SDValue B) {
     SDLoc DL(N);
     SDValue NewY =
@@ -3507,49 +3506,96 @@
   SDValue Carry0 = getAsCarry(TLI, N0);
   if (!Carry0)
     return SDValue();
+
   SDValue Carry1 = getAsCarry(TLI, N1);
   if (!Carry1)
     return SDValue();
 
-  unsigned Opcode = Carry0.getOpcode();
-  if (Opcode != Carry1.getOpcode())
+  // Match an unsigned add/sub with overflow and return its operands in A and
+  // B. Besides plain UADDO/USUBO, (UADDO_CARRY A, 0, C) and
+  // (USUBO_CARRY A, 0, C) are accepted as A + C and A - C; for those forms B
+  // is the carry-in operand C rather than a value of A's type.
+  auto matchCarry = [](SDValue N, SDValue &A, SDValue &B, bool &IsAdd) {
+    unsigned Opcode = N.getOpcode();
+    IsAdd = Opcode == ISD::UADDO || Opcode == ISD::UADDO_CARRY;
+
+    if (Opcode == ISD::UADDO || Opcode == ISD::USUBO) {
+      A = N.getOperand(0);
+      B = N.getOperand(1);
+      return true;
+    }
+
+    if ((Opcode == ISD::UADDO_CARRY || Opcode == ISD::USUBO_CARRY) &&
+        isNullConstant(N.getOperand(1))) {
+      A = N.getOperand(0);
+      B = N.getOperand(2);
+      return true;
+    }
+
+    return false;
+  };
+
+  SDValue X, Y;
+  bool IsAdd;
+  if (!matchCarry(Carry0, X, Y, IsAdd))
     return SDValue();
-  if (Opcode != ISD::UADDO && Opcode != ISD::USUBO)
+
+  SDValue A, B;
+  bool IsAdd1;
+  if (!matchCarry(Carry1, A, B, IsAdd1))
+    return SDValue();
+
+  if (IsAdd != IsAdd1)
+    return SDValue();
+
+  unsigned NewOp = IsAdd ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
+  if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
     return SDValue();
 
   // Canonicalize the add/sub of A and B (the top node in the above ASCII art)
-  // as Carry0 and the add/sub of the carry in as Carry1 (the middle node).
-  if (Carry1.getNode()->isOperandOf(Carry0.getNode()))
+  // as Carry1 and the add/sub of the carry in as Carry0 (the middle node).
+  if (Carry0.getValue(0) == A || Carry0.getValue(0) == B) {
     std::swap(Carry0, Carry1);
+    std::swap(A, X);
+    std::swap(B, Y);
+  }
 
+  // A and B become the value operands of the merged node, so the top node
+  // must be a plain UADDO/USUBO: in the UADDO_CARRY/USUBO_CARRY form matched
+  // above, B is a carry operand rather than a value of A's type.
+  if (Carry1.getOpcode() != ISD::UADDO && Carry1.getOpcode() != ISD::USUBO)
+    return SDValue();
+
   // Check if nodes are connected in expected way.
-  if (Carry1.getOperand(0) != Carry0.getValue(0) &&
-      Carry1.getOperand(1) != Carry0.getValue(0))
-    return SDValue();
-
-  // The carry in value must be on the righthand side for subtraction.
-  unsigned CarryInOperandNum =
-      Carry1.getOperand(0) == Carry0.getValue(0) ? 1 : 0;
-  if (Opcode == ISD::USUBO && CarryInOperandNum != 1)
-    return SDValue();
-  SDValue CarryIn = Carry1.getOperand(CarryInOperandNum);
-
-  unsigned NewOp = Opcode == ISD::UADDO ? ISD::UADDO_CARRY : ISD::USUBO_CARRY;
-  if (!TLI.isOperationLegalOrCustom(NewOp, Carry0.getValue(0).getValueType()))
+  // The top node must feed X, the left operand of the middle node; only for
+  // addition may it feed Y instead (subtraction is not commutative).
+  if (Carry1.getValue(0) != X && (!IsAdd || Carry1.getValue(0) != Y))
     return SDValue();
 
   // Verify that the carry/borrow in is plausibly a carry/borrow bit.
   // TODO: make getAsCarry() aware of how partial carries are merged.
-  if (CarryIn.getOpcode() != ISD::ZERO_EXTEND)
-    return SDValue();
-  CarryIn = CarryIn.getOperand(0);
+  SDValue CarryIn = Carry1.getValue(0) == X ? Y : X;
+  if (CarryIn.getOpcode() == ISD::ZERO_EXTEND)
+    CarryIn = CarryIn.getOperand(0);
+
   if (CarryIn.getValueType() != MVT::i1)
     return SDValue();
 
   SDLoc DL(N);
   SDValue Merged =
-      DAG.getNode(NewOp, DL, Carry1->getVTList(), Carry0.getOperand(0),
-                  Carry0.getOperand(1), CarryIn);
+      DAG.getNode(NewOp, DL, Carry1->getVTList(), A, B, CarryIn);
+
+  LLVM_DEBUG(
+    dbgs() << "Combine carry diamond:\n";
+    dbgs() << "\tCarry0: "; Carry0->dump();
+    dbgs() << "\tX: "; X->dump();
+    dbgs() << "\tY: "; Y->dump();
+    dbgs() << "\tCarry1: "; Carry1->dump();
+    dbgs() << "\tA: "; A->dump();
+    dbgs() << "\tB: "; B->dump();
+    dbgs() << "\tCarryIn: "; CarryIn->dump();
+    dbgs() << "\tMerged: "; Merged->dump();
+    dbgs() << "\n");
 
   // Please note that because we have proven that the result of the UADDO/USUBO
   // of A and B feeds into the UADDO/USUBO that does the carry/borrow in, we can
@@ -3564,7 +3610,7 @@
   // carry flags; and that AND can return a constant zero.
   //
   // TODO: match other operations that can merge flags (ADD, etc)
-  DAG.ReplaceAllUsesOfValueWith(Carry1.getValue(0), Merged.getValue(0));
+  DAG.ReplaceAllUsesOfValueWith(Carry0.getValue(0), Merged.getValue(0));
   if (N->getOpcode() == ISD::AND)
     return DAG.getConstant(0, DL, MVT::i1);
   return Merged.getValue(1);