diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -6101,6 +6101,14 @@
   return Ok;
 }
 
+// Check that (every element of) Z is undef or not an exact multiple of BW.
+static bool isNonZeroModBitWidthOrUndef(SDValue Z, unsigned BW) {
+  return ISD::matchUnaryPredicate(
+      Z,
+      [=](ConstantSDNode *C) { return !C || C->getAPIntValue().urem(BW) != 0; },
+      /*AllowUndefs=*/true);
+}
+
 bool TargetLowering::expandFunnelShift(SDNode *Node, SDValue &Result,
                                        SelectionDAG &DAG) const {
   EVT VT = Node->getValueType(0);
@@ -6111,40 +6119,56 @@
                         !isOperationLegalOrCustomOrPromote(ISD::OR, VT)))
     return false;
 
-  // fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
-  // fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
   SDValue X = Node->getOperand(0);
   SDValue Y = Node->getOperand(1);
   SDValue Z = Node->getOperand(2);
 
-  unsigned EltSizeInBits = VT.getScalarSizeInBits();
+  unsigned BW = VT.getScalarSizeInBits();
   bool IsFSHL = Node->getOpcode() == ISD::FSHL;
   SDLoc DL(SDValue(Node, 0));
 
   EVT ShVT = Z.getValueType();
-  SDValue Mask = DAG.getConstant(EltSizeInBits - 1, DL, ShVT);
-  SDValue ShAmt, InvShAmt;
-  if (isPowerOf2_32(EltSizeInBits)) {
-    // Z % BW -> Z & (BW - 1)
-    ShAmt = DAG.getNode(ISD::AND, DL, ShVT, Z, Mask);
-    // (BW - 1) - (Z % BW) -> ~Z & (BW - 1)
-    InvShAmt = DAG.getNode(ISD::AND, DL, ShVT, DAG.getNOT(DL, Z, ShVT), Mask);
-  } else {
-    SDValue BitWidthC = DAG.getConstant(EltSizeInBits, DL, ShVT);
-    ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC);
-    InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, Mask, ShAmt);
-  }
 
-  SDValue One = DAG.getConstant(1, DL, ShVT);
   SDValue ShX, ShY;
-  if (IsFSHL) {
-    ShX = DAG.getNode(ISD::SHL, DL, VT, X, ShAmt);
-    SDValue ShY1 = DAG.getNode(ISD::SRL, DL, VT, Y, One);
-    ShY = DAG.getNode(ISD::SRL, DL, VT, ShY1, InvShAmt);
+  SDValue ShAmt, InvShAmt;
+  if (isNonZeroModBitWidthOrUndef(Z, BW)) {
+    // fshl: X << C | Y >> (BW - C)
+    // fshr: X << (BW - C) | Y >> C
+    // where C = Z % BW is not zero
+    // Because C != 0, BW - C is strictly less than BW, so each operand needs
+    // only one shift (compare the split shift in the general case below).
+    SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT);
+    ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC);
+    InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, BitWidthC, ShAmt);
+    ShX = DAG.getNode(ISD::SHL, DL, VT, X, IsFSHL ? ShAmt : InvShAmt);
+    ShY = DAG.getNode(ISD::SRL, DL, VT, Y, IsFSHL ? InvShAmt : ShAmt);
   } else {
-    SDValue ShX1 = DAG.getNode(ISD::SHL, DL, VT, X, One);
-    ShX = DAG.getNode(ISD::SHL, DL, VT, ShX1, InvShAmt);
-    ShY = DAG.getNode(ISD::SRL, DL, VT, Y, ShAmt);
+    // fshl: X << (Z % BW) | Y >> 1 >> (BW - 1 - (Z % BW))
+    // fshr: X << 1 << (BW - 1 - (Z % BW)) | Y >> (Z % BW)
+    // Z % BW may be zero here, so the complementary shift is split into a
+    // shift by 1 and a shift by BW - 1 - (Z % BW) to avoid shifting by BW.
+    SDValue Mask = DAG.getConstant(BW - 1, DL, ShVT);
+    if (isPowerOf2_32(BW)) {
+      // Z % BW -> Z & (BW - 1)
+      ShAmt = DAG.getNode(ISD::AND, DL, ShVT, Z, Mask);
+      // (BW - 1) - (Z % BW) -> ~Z & (BW - 1)
+      InvShAmt = DAG.getNode(ISD::AND, DL, ShVT, DAG.getNOT(DL, Z, ShVT), Mask);
+    } else {
+      SDValue BitWidthC = DAG.getConstant(BW, DL, ShVT);
+      ShAmt = DAG.getNode(ISD::UREM, DL, ShVT, Z, BitWidthC);
+      InvShAmt = DAG.getNode(ISD::SUB, DL, ShVT, Mask, ShAmt);
+    }
+
+    SDValue One = DAG.getConstant(1, DL, ShVT);
+    if (IsFSHL) {
+      ShX = DAG.getNode(ISD::SHL, DL, VT, X, ShAmt);
+      SDValue ShY1 = DAG.getNode(ISD::SRL, DL, VT, Y, One);
+      ShY = DAG.getNode(ISD::SRL, DL, VT, ShY1, InvShAmt);
+    } else {
+      SDValue ShX1 = DAG.getNode(ISD::SHL, DL, VT, X, One);
+      ShX = DAG.getNode(ISD::SHL, DL, VT, ShX1, InvShAmt);
+      ShY = DAG.getNode(ISD::SRL, DL, VT, Y, ShAmt);
+    }
   }
   Result = DAG.getNode(ISD::OR, DL, VT, ShX, ShY);
   return true;