[AArch64] foldCTTZ: use early returns and fix the mask for truncated cttz

Restructure foldCTTZ around early returns instead of the IsEQOrNE flag,
and let the caller test the result with `if (SDValue Folded = ...)`.

When the CSEL operand is trunc(cttz(X)), the count is computed in X's
width (i64), so a non-zero X can produce a count in [32, 63] that
survives the truncate to i32. Masking with (i32 width - 1) = 31 would
turn e.g. a count of 40 into 8. Mask with (cttz source width - 1)
instead; cttz(0) = 64 still maps to 0 because 64 & 63 == 0.

Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -17449,38 +17449,49 @@
 static SDValue foldCTTZ(SDNode *N, SelectionDAG &DAG) {
   unsigned CC = N->getConstantOperandVal(2);
   SDValue SUBS = N->getOperand(3);
   SDValue Zero, CTTZ;
-  bool IsEQOrNE = false;
 
   if (CC == AArch64CC::EQ && SUBS.getOpcode() == AArch64ISD::SUBS) {
     Zero = N->getOperand(0);
     CTTZ = N->getOperand(1);
-    IsEQOrNE = true;
   } else if (CC == AArch64CC::NE && SUBS.getOpcode() == AArch64ISD::SUBS) {
     Zero = N->getOperand(1);
     CTTZ = N->getOperand(0);
-    IsEQOrNE = true;
-  }
-
-  if (IsEQOrNE &&
-      (CTTZ.getOpcode() == ISD::CTTZ ||
-       (CTTZ.getOpcode() == ISD::TRUNCATE &&
-        CTTZ.getOperand(0).getOpcode() == ISD::CTTZ))) {
-    assert(
-        (CTTZ.getValueType() == MVT::i32 || CTTZ.getValueType() == MVT::i64) &&
-        "Illegal type in CTTZ folding");
-    if (isNullConstant(Zero) &&
-        isNullConstant(SUBS.getValue(1).getOperand(1))) {
-      SDValue X = CTTZ.getOpcode() == ISD::TRUNCATE ? CTTZ.getOperand(0).getOperand(0) : CTTZ.getOperand(0);
-      if(X == SUBS.getOperand(0)) {
-        unsigned BitWidth = CTTZ.getValueSizeInBits();
-        SDValue BitWidthMinusOne = DAG.getConstant(BitWidth-1, SDLoc(N), CTTZ.getValueType());
-        return DAG.getNode(ISD::AND, SDLoc(N), CTTZ.getValueType(), CTTZ, BitWidthMinusOne);
-      }
-    }
-  }
-
-  return SDValue();
+  } else
+    return SDValue();
+
+  // Only cttz(X), possibly truncated to a narrower type, can be folded.
+  if ((CTTZ.getOpcode() != ISD::CTTZ && CTTZ.getOpcode() != ISD::TRUNCATE) ||
+      (CTTZ.getOpcode() == ISD::TRUNCATE &&
+       CTTZ.getOperand(0).getOpcode() != ISD::CTTZ))
+    return SDValue();
+
+  assert((CTTZ.getValueType() == MVT::i32 || CTTZ.getValueType() == MVT::i64) &&
+         "Illegal type in CTTZ folding");
+
+  // The selected constant must be 0 and the SUBS must compare against 0.
+  if (!isNullConstant(Zero) || !isNullConstant(SUBS.getOperand(1)))
+    return SDValue();
+
+  // Look through the optional truncate to find the cttz input, which must be
+  // the same value the SUBS compares against zero.
+  SDValue X = CTTZ.getOpcode() == ISD::TRUNCATE
+                  ? CTTZ.getOperand(0).getOperand(0)
+                  : CTTZ.getOperand(0);
+  if (X != SUBS.getOperand(0))
+    return SDValue();
+
+  // cttz(0) yields the bit width, so AND-ing with (bit width - 1) maps it to 0
+  // and leaves every other count unchanged. The width must be the one cttz was
+  // computed in: for trunc(cttz(i64 X)) to i32 a non-zero X can give a count
+  // in [32, 63], which a mask of 31 would corrupt.
+  unsigned BitWidth = CTTZ.getOpcode() == ISD::TRUNCATE
+                          ? CTTZ.getOperand(0).getValueSizeInBits()
+                          : CTTZ.getValueSizeInBits();
+  SDValue BitWidthMinusOne =
+      DAG.getConstant(BitWidth - 1, SDLoc(N), CTTZ.getValueType());
+  return DAG.getNode(ISD::AND, SDLoc(N), CTTZ.getValueType(), CTTZ,
+                     BitWidthMinusOne);
 }
 
 // Optimize CSEL instructions
@@ -17494,8 +17505,7 @@
 
   // CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1
   // CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1
-  SDValue Folded = foldCTTZ(N, DAG);
-  if (Folded.getNode() != nullptr)
+  if (SDValue Folded = foldCTTZ(N, DAG))
     return Folded;
 
   return performCONDCombine(N, DCI, DAG, 2, 3);