@@ -2595,27 +2595,49 @@ SDValue AMDGPUTargetLowering::splitBinaryBitConstantOpImpl(
2595
2595
2596
2596
SDValue AMDGPUTargetLowering::performShlCombine (SDNode *N,
2597
2597
DAGCombinerInfo &DCI) const {
2598
- if (N->getValueType (0 ) != MVT::i64)
2598
+ EVT VT = N->getValueType (0 );
2599
+ if (VT != MVT::i64)
2599
2600
return SDValue ();
2600
2601
2601
- // i64 (shl x, C) -> (build_pair 0, (shl x, C -32))
2602
-
2603
- // On some subtargets, 64-bit shift is a quarter rate instruction. In the
2604
- // common case, splitting this into a move and a 32-bit shift is faster and
2605
- // the same code size.
2606
- const ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(N->getOperand (1 ));
2602
+ ConstantSDNode *RHS = dyn_cast<ConstantSDNode>(N->getOperand (1 ));
2607
2603
if (!RHS)
2608
2604
return SDValue ();
2609
2605
2610
- unsigned RHSVal = RHS->getZExtValue ();
2611
- if (RHSVal < 32 )
2612
- return SDValue ();
2613
-
2614
2606
SDValue LHS = N->getOperand (0 );
2607
+ unsigned RHSVal = RHS->getZExtValue ();
2608
+ if (!RHSVal)
2609
+ return LHS;
2615
2610
2616
2611
SDLoc SL (N);
2617
2612
SelectionDAG &DAG = DCI.DAG ;
2618
2613
2614
+ switch (LHS->getOpcode ()) {
2615
+ default :
2616
+ break ;
2617
+ case ISD::ZERO_EXTEND:
2618
+ case ISD::SIGN_EXTEND:
2619
+ case ISD::ANY_EXTEND: {
2620
+ // shl (ext x) => zext (shl x), if shift does not overflow int
2621
+ KnownBits Known;
2622
+ SDValue X = LHS->getOperand (0 );
2623
+ DAG.computeKnownBits (X, Known);
2624
+ unsigned LZ = Known.countMinLeadingZeros ();
2625
+ if (LZ < RHSVal)
2626
+ break ;
2627
+ EVT XVT = X.getValueType ();
2628
+ SDValue Shl = DAG.getNode (ISD::SHL, SL, XVT, X, SDValue (RHS, 0 ));
2629
+ return DAG.getZExtOrTrunc (Shl, SL, VT);
2630
+ }
2631
+ }
2632
+
2633
+ // i64 (shl x, C) -> (build_pair 0, (shl x, C -32))
2634
+
2635
+ // On some subtargets, 64-bit shift is a quarter rate instruction. In the
2636
+ // common case, splitting this into a move and a 32-bit shift is faster and
2637
+ // the same code size.
2638
+ if (RHSVal < 32 )
2639
+ return SDValue ();
2640
+
2619
2641
SDValue ShiftAmt = DAG.getConstant (RHSVal - 32 , SL, MVT::i32);
2620
2642
2621
2643
SDValue Lo = DAG.getNode (ISD::TRUNCATE, SL, MVT::i32, LHS);
0 commit comments