// Patch summary (unified diff against llvm/lib/Target/AArch64/AArch64ISelLowering.cpp).
//
// Hunk 1: removes the header and leading statements of performCSELCombine at this
//         position and introduces `static SDValue foldCTTZ(SDNode *N, SelectionDAG &DAG)`;
//         the existing statements that read the condition code (operand 2) and the
//         SUBS node (operand 3) become the start of the new helper's body.
// Hunk 2: adds an assert, placed just before the isNullConstant checks, that the
//         matched CTTZ value has type MVT::i32 or MVT::i64
//         ("Illegal type in CTTZ folding").
// Hunk 3: terminates foldCTTZ with `return SDValue();` and re-adds performCSELCombine
//         below it. The re-added function keeps the `CSEL x, x, cc -> x` shortcut,
//         calls foldCTTZ(N, DAG), returns that result when its node is non-null, and
//         otherwise falls through to performCONDCombine(N, DCI, DAG, 2, 3).
//
// NOTE(review): the line breaks and leading whitespace of this diff appear to have
// been lost (the whole patch sits on one line); they presumably need to be restored
// before the patch can be applied -- confirm against the original review.
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -17446,16 +17446,7 @@ return SDValue(); } -// Optimize CSEL instructions -static SDValue performCSELCombine(SDNode *N, - TargetLowering::DAGCombinerInfo &DCI, - SelectionDAG &DAG) { - // CSEL x, x, cc -> x - if (N->getOperand(0) == N->getOperand(1)) - return N->getOperand(0); - - // CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1 - // CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1 +static SDValue foldCTTZ(SDNode *N, SelectionDAG &DAG) { unsigned CC = N->getConstantOperandVal(2); SDValue SUBS = N->getOperand(3); SDValue Zero, CTTZ; @@ -17476,6 +17467,9 @@ (CTTZ.getOpcode() == ISD::CTTZ || (CTTZ.getOpcode() == ISD::TRUNCATE && CTTZ.getOperand(0).getOpcode() == ISD::CTTZ))) { + assert( + (CTTZ.getValueType() == MVT::i32 || CTTZ.getValueType() == MVT::i64) && + "Illegal type in CTTZ folding"); if (isNullConstant(Zero) && isNullConstant(SUBS.getValue(1).getOperand(1))) { SDValue X = CTTZ.getOpcode() == ISD::TRUNCATE ? CTTZ.getOperand(0).getOperand(0) : CTTZ.getOperand(0); @@ -17487,6 +17481,23 @@ } } + return SDValue(); +} + +// Optimize CSEL instructions +static SDValue performCSELCombine(SDNode *N, + TargetLowering::DAGCombinerInfo &DCI, + SelectionDAG &DAG) { + // CSEL x, x, cc -> x + if (N->getOperand(0) == N->getOperand(1)) + return N->getOperand(0); + + // CSEL 0, cttz(X), eq(X, 0) -> AND cttz bitwidth-1 + // CSEL cttz(X), 0, ne(X, 0) -> AND cttz bitwidth-1 + SDValue Folded = foldCTTZ(N, DAG); + if (Folded.getNode() != nullptr) + return Folded; + return performCONDCombine(N, DCI, DAG, 2, 3); }