diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -6741,46 +6741,28 @@ assert(Op.getOpcode() == ISD::SRA_PARTS || Op.getOpcode() == ISD::SRL_PARTS); - SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, - DAG.getConstant(VTBits, dl, MVT::i64), ShAmt); - SDValue HiBitsForLo = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, RevShAmt); - - // Unfortunately, if ShAmt == 0, we just calculated "(SHL ShOpHi, 64)" which - // is "undef". We wanted 0, so CSEL it directly. - SDValue Cmp = emitComparison(ShAmt, DAG.getConstant(0, dl, MVT::i64), - ISD::SETEQ, dl, DAG); - SDValue CCVal = DAG.getConstant(AArch64CC::EQ, dl, MVT::i32); - HiBitsForLo = - DAG.getNode(AArch64ISD::CSEL, dl, VT, DAG.getConstant(0, dl, MVT::i64), - HiBitsForLo, CCVal, Cmp); - - SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, ShAmt, - DAG.getConstant(VTBits, dl, MVT::i64)); - - SDValue LoBitsForLo = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, ShAmt); - SDValue LoForNormalShift = - DAG.getNode(ISD::OR, dl, VT, LoBitsForLo, HiBitsForLo); - - Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64), ISD::SETGE, - dl, DAG); - CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32); - SDValue LoForBigShift = DAG.getNode(Opc, dl, VT, ShOpHi, ExtraShAmt); - SDValue Lo = DAG.getNode(AArch64ISD::CSEL, dl, VT, LoForBigShift, - LoForNormalShift, CCVal, Cmp); - - // AArch64 shifts larger than the register width are wrapped rather than - // clamped, so we can't just emit "hi >> x". - SDValue HiForNormalShift = DAG.getNode(Opc, dl, VT, ShOpHi, ShAmt); - SDValue HiForBigShift = - Opc == ISD::SRA - ? DAG.getNode(Opc, dl, VT, ShOpHi, - DAG.getConstant(VTBits - 1, dl, MVT::i64)) - : DAG.getConstant(0, dl, VT); - SDValue Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, HiForBigShift, - HiForNormalShift, CCVal, Cmp); - - SDValue Ops[2] = { Lo, Hi }; - return DAG.getMergeValues(Ops, dl); + // Compute the shift with the shift amount masked. + SDValue MaskedShAmt = DAG.getNode(ISD::AND, dl, MVT::i64, ShAmt, + DAG.getConstant(VTBits - 1, dl, MVT::i64)); + SDValue Hi = DAG.getNode(Opc, dl, MVT::i64, ShOpHi, MaskedShAmt); + SDValue Lo = DAG.getNode(ISD::FSHR, dl, MVT::i64, ShOpHi, ShOpLo, ShAmt); + + // Check if the shift is large. + SDValue Cmp = emitComparison(ShAmt, DAG.getConstant(VTBits, dl, MVT::i64), + ISD::SETGE, dl, DAG); + // Adjust the results. + SDValue CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32); + Lo = DAG.getNode(AArch64ISD::CSEL, dl, VT, Hi, Lo, CCVal, Cmp); + if (Op.getOpcode() == ISD::SRA_PARTS) { + Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, + DAG.getNode(ISD::SRA, dl, MVT::i64, ShOpHi, + DAG.getConstant(VTBits - 1, dl, MVT::i64)), + Hi, CCVal, Cmp); + } else { + Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, DAG.getConstant(0, dl, MVT::i64), + Hi, CCVal, Cmp); + } + return DAG.getMergeValues({Lo, Hi}, dl); } /// LowerShiftLeftParts - Lower SHL_PARTS, which returns two @@ -6796,42 +6778,22 @@ SDValue ShAmt = Op.getOperand(2); assert(Op.getOpcode() == ISD::SHL_PARTS); - SDValue RevShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, - DAG.getConstant(VTBits, dl, MVT::i64), ShAmt); - SDValue LoBitsForHi = DAG.getNode(ISD::SRL, dl, VT, ShOpLo, RevShAmt); - - // Unfortunately, if ShAmt == 0, we just calculated "(SRL ShOpLo, 64)" which - // is "undef". We wanted 0, so CSEL it directly. - SDValue Cmp = emitComparison(ShAmt, DAG.getConstant(0, dl, MVT::i64), - ISD::SETEQ, dl, DAG); - SDValue CCVal = DAG.getConstant(AArch64CC::EQ, dl, MVT::i32); - LoBitsForHi = - DAG.getNode(AArch64ISD::CSEL, dl, VT, DAG.getConstant(0, dl, MVT::i64), - LoBitsForHi, CCVal, Cmp); - - SDValue ExtraShAmt = DAG.getNode(ISD::SUB, dl, MVT::i64, ShAmt, - DAG.getConstant(VTBits, dl, MVT::i64)); - SDValue HiBitsForHi = DAG.getNode(ISD::SHL, dl, VT, ShOpHi, ShAmt); - SDValue HiForNormalShift = - DAG.getNode(ISD::OR, dl, VT, LoBitsForHi, HiBitsForHi); - - SDValue HiForBigShift = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ExtraShAmt); - - Cmp = emitComparison(ExtraShAmt, DAG.getConstant(0, dl, MVT::i64), ISD::SETGE, - dl, DAG); - CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32); - SDValue Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, HiForBigShift, - HiForNormalShift, CCVal, Cmp); - - // AArch64 shifts of larger than register sizes are wrapped rather than - // clamped, so we can't just emit "lo << a" if a is too big. - SDValue LoForBigShift = DAG.getConstant(0, dl, VT); - SDValue LoForNormalShift = DAG.getNode(ISD::SHL, dl, VT, ShOpLo, ShAmt); - SDValue Lo = DAG.getNode(AArch64ISD::CSEL, dl, VT, LoForBigShift, - LoForNormalShift, CCVal, Cmp); - - SDValue Ops[2] = { Lo, Hi }; - return DAG.getMergeValues(Ops, dl); + + // Compute the shift with the shift amount masked. + SDValue MaskedShAmt = DAG.getNode(ISD::AND, dl, MVT::i64, ShAmt, + DAG.getConstant(VTBits - 1, dl, MVT::i64)); + SDValue Lo = DAG.getNode(ISD::SHL, dl, MVT::i64, ShOpLo, MaskedShAmt); + SDValue Hi = DAG.getNode(ISD::FSHL, dl, MVT::i64, ShOpHi, ShOpLo, ShAmt); + + // Check if the shift is large. + SDValue Cmp = emitComparison(ShAmt, DAG.getConstant(VTBits, dl, MVT::i64), + ISD::SETGE, dl, DAG); + // Adjust the results. + SDValue CCVal = DAG.getConstant(AArch64CC::GE, dl, MVT::i32); + Hi = DAG.getNode(AArch64ISD::CSEL, dl, VT, Lo, Hi, CCVal, Cmp); + Lo = DAG.getNode(AArch64ISD::CSEL, dl, VT, DAG.getConstant(0, dl, MVT::i64), + Lo, CCVal, Cmp); + return DAG.getMergeValues({Lo, Hi}, dl); } bool AArch64TargetLowering::isOffsetFoldingLegal( diff --git a/llvm/test/CodeGen/AArch64/arm64-long-shift.ll b/llvm/test/CodeGen/AArch64/arm64-long-shift.ll --- a/llvm/test/CodeGen/AArch64/arm64-long-shift.ll +++ b/llvm/test/CodeGen/AArch64/arm64-long-shift.ll @@ -4,15 +4,13 @@ define i128 @shl(i128 %r, i128 %s) nounwind readnone { ; CHECK-LABEL: shl: ; CHECK: // %bb.0: -; CHECK-NEXT: neg x8, x2 -; CHECK-NEXT: lsr x8, x0, x8 -; CHECK-NEXT: cmp x2, #0 // =0 -; CHECK-NEXT: csel x8, xzr, x8, eq -; CHECK-NEXT: lsl x9, x1, x2 +; CHECK-NEXT: lsl x8, x1, x2 +; CHECK-NEXT: mvn w9, w2 +; CHECK-NEXT: lsr x10, x0, #1 +; CHECK-NEXT: lsr x9, x10, x9 ; CHECK-NEXT: orr x8, x8, x9 ; CHECK-NEXT: lsl x9, x0, x2 -; CHECK-NEXT: sub x10, x2, #64 // =64 -; CHECK-NEXT: cmp x10, #0 // =0 +; CHECK-NEXT: cmp x2, #64 // =64 ; CHECK-NEXT: csel x1, x9, x8, ge ; CHECK-NEXT: csel x0, xzr, x9, ge ; CHECK-NEXT: ret @@ -39,15 +37,13 @@ define i128 @ashr(i128 %r, i128 %s) nounwind readnone { ; CHECK-LABEL: ashr: ; CHECK: // %bb.0: -; CHECK-NEXT: neg x8, x2 -; CHECK-NEXT: lsl x8, x1, x8 -; CHECK-NEXT: cmp x2, #0 // =0 -; CHECK-NEXT: csel x8, xzr, x8, eq -; CHECK-NEXT: lsr x9, x0, x2 +; CHECK-NEXT: lsr x8, x0, x2 +; CHECK-NEXT: mvn w9, w2 +; CHECK-NEXT: lsl x10, x1, #1 +; CHECK-NEXT: lsl x9, x10, x9 ; CHECK-NEXT: orr x8, x9, x8 ; CHECK-NEXT: asr x9, x1, x2 -; CHECK-NEXT: sub x10, x2, #64 // =64 -; CHECK-NEXT: cmp x10, #0 // =0 +; CHECK-NEXT: cmp x2, #64 // =64 ; CHECK-NEXT: csel x0, x9, x8, ge ; CHECK-NEXT: asr x8, x1, #63 ; CHECK-NEXT: csel x1, x8, x9, ge @@ -75,15 +71,13 @@ define i128 @lshr(i128 %r, i128 %s) nounwind readnone { ; CHECK-LABEL: lshr: ; CHECK: // %bb.0: -; CHECK-NEXT: neg x8, x2 -; CHECK-NEXT: lsl x8, x1, x8 -; CHECK-NEXT: cmp x2, #0 // =0 -; CHECK-NEXT: csel x8, xzr, x8, eq -; CHECK-NEXT: lsr x9, x0, x2 +; CHECK-NEXT: lsr x8, x0, x2 +; CHECK-NEXT: mvn w9, w2 +; CHECK-NEXT: lsl x10, x1, #1 +; CHECK-NEXT: lsl x9, x10, x9 ; CHECK-NEXT: orr x8, x9, x8 ; CHECK-NEXT: lsr x9, x1, x2 -; CHECK-NEXT: sub x10, x2, #64 // =64 -; CHECK-NEXT: cmp x10, #0 // =0 +; CHECK-NEXT: cmp x2, #64 // =64 ; CHECK-NEXT: csel x0, x9, x8, ge ; CHECK-NEXT: csel x1, xzr, x9, ge ; CHECK-NEXT: ret