Index: include/llvm/CodeGen/TargetLowering.h =================================================================== --- include/llvm/CodeGen/TargetLowering.h +++ include/llvm/CodeGen/TargetLowering.h @@ -2947,7 +2947,8 @@ // This transformation may not be desirable if it disrupts a particularly // auspicious target-specific tree (e.g. bitfield extraction in AArch64). // By default, it returns true. - virtual bool isDesirableToCommuteWithShift(const SDNode *N) const { + virtual bool isDesirableToCommuteWithShift(const SDNode *N, + CombineLevel Level) const { return true; } Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp =================================================================== --- lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -6206,7 +6206,7 @@ return SDValue(); } - if (!TLI.isDesirableToCommuteWithShift(LHS)) + if (!TLI.isDesirableToCommuteWithShift(N, Level)) return SDValue(); // Fold the constants, shifting the binop RHS by the shift amount. @@ -6510,7 +6510,8 @@ if ((N0.getOpcode() == ISD::ADD || N0.getOpcode() == ISD::OR) && N0.getNode()->hasOneUse() && isConstantOrConstantVector(N1, /* No Opaques */ true) && - isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true)) { + isConstantOrConstantVector(N0.getOperand(1), /* No Opaques */ true) && + TLI.isDesirableToCommuteWithShift(N, Level)) { SDValue Shl0 = DAG.getNode(ISD::SHL, SDLoc(N0), VT, N0.getOperand(0), N1); SDValue Shl1 = DAG.getNode(ISD::SHL, SDLoc(N1), VT, N0.getOperand(1), N1); AddToWorklist(Shl0.getNode()); Index: lib/Target/AArch64/AArch64ISelLowering.h =================================================================== --- lib/Target/AArch64/AArch64ISelLowering.h +++ lib/Target/AArch64/AArch64ISelLowering.h @@ -363,7 +363,8 @@ const MCPhysReg *getScratchRegisters(CallingConv::ID CC) const override; /// Returns false if N is a bit extraction pattern of (X >> C) & Mask. 
- bool isDesirableToCommuteWithShift(const SDNode *N) const override; + bool isDesirableToCommuteWithShift(const SDNode *N, + CombineLevel Level) const override; /// Returns true if it is beneficial to convert a load of a constant /// to just the constant itself. Index: lib/Target/AArch64/AArch64ISelLowering.cpp =================================================================== --- lib/Target/AArch64/AArch64ISelLowering.cpp +++ lib/Target/AArch64/AArch64ISelLowering.cpp @@ -8473,7 +8473,9 @@ } bool -AArch64TargetLowering::isDesirableToCommuteWithShift(const SDNode *N) const { +AArch64TargetLowering::isDesirableToCommuteWithShift(const SDNode *N, + CombineLevel Level) const { + N = N->getOperand(0).getNode(); EVT VT = N->getValueType(0); // If N is unsigned bit extraction: ((x >> C) & mask), then do not combine // it with shift to let it be lowered to UBFX. Index: lib/Target/ARM/ARMISelLowering.h =================================================================== --- lib/Target/ARM/ARMISelLowering.h +++ lib/Target/ARM/ARMISelLowering.h @@ -586,6 +586,9 @@ unsigned getABIAlignmentForCallingConv(Type *ArgTy, DataLayout DL) const override; + bool isDesirableToCommuteWithShift(const SDNode *N, + CombineLevel Level) const override; + protected: std::pair findRepresentativeClass(const TargetRegisterInfo *TRI, Index: lib/Target/ARM/ARMISelLowering.cpp =================================================================== --- lib/Target/ARM/ARMISelLowering.cpp +++ lib/Target/ARM/ARMISelLowering.cpp @@ -10447,6 +10447,25 @@ return SDValue(); } +bool +ARMTargetLowering::isDesirableToCommuteWithShift(const SDNode *N, + CombineLevel Level) const { + if (Level == BeforeLegalizeTypes) + return true; + + if (Subtarget->isThumb() && Subtarget->isThumb1Only()) + return true; + + if (N->getOpcode() != ISD::SHL) + return true; + + // Unless at BeforeLegalizeTypes, turn off commute-with-shift so it doesn't + // conflict with PerformSHLSimplify. 
(We could try to detect when + // PerformSHLSimplify would trigger more precisely, but it isn't + // really necessary.) + return false; +} + static SDValue PerformSHLSimplify(SDNode *N, TargetLowering::DAGCombinerInfo &DCI, const ARMSubtarget *ST) { @@ -10546,9 +10565,7 @@ LLVM_DEBUG(dbgs() << "Simplify shl use:\n"; SHL.getOperand(0).dump(); SHL.dump(); N->dump()); LLVM_DEBUG(dbgs() << "Into:\n"; X.dump(); BinOp.dump(); Res.dump()); - - DAG.ReplaceAllUsesWith(SDValue(N, 0), Res); - return SDValue(N, 0); + return Res; }