diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -593,7 +593,7 @@
     SDValue MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
     SDValue MatchLoadCombine(SDNode *N);
     SDValue mergeTruncStores(StoreSDNode *N);
-    SDValue ReduceLoadWidth(SDNode *N);
+    SDValue reduceLoadWidth(SDNode *N);
     SDValue ReduceLoadOpStoreWidth(SDNode *N);
     SDValue splitMergedValStore(StoreSDNode *ST);
     SDValue TransformFPLoadStorePair(SDNode *N);
@@ -5624,7 +5624,7 @@
     if (And.getOpcode() == ISD ::AND)
       And = SDValue(
           DAG.UpdateNodeOperands(And.getNode(), SDValue(Load, 0), MaskOp), 0);
-    SDValue NewLoad = ReduceLoadWidth(And.getNode());
+    SDValue NewLoad = reduceLoadWidth(And.getNode());
     assert(NewLoad &&
            "Shouldn't be masking the load if it can't be narrowed");
     CombineTo(Load, NewLoad, NewLoad.getValue(1));
@@ -6024,7 +6024,7 @@
   if (!VT.isVector() && N1C && (N0.getOpcode() == ISD::LOAD ||
                                 (N0.getOpcode() == ISD::ANY_EXTEND &&
                                  N0.getOperand(0).getOpcode() == ISD::LOAD))) {
-    if (SDValue Res = ReduceLoadWidth(N)) {
+    if (SDValue Res = reduceLoadWidth(N)) {
       LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND
                             ? cast<LoadSDNode>(N0.getOperand(0))
                             : cast<LoadSDNode>(N0);
       AddToWorklist(N);
@@ -9140,7 +9140,7 @@
     return NewSRL;
 
   // Attempt to convert a srl of a load into a narrower zero-extending load.
-  if (SDValue NarrowLoad = ReduceLoadWidth(N))
+  if (SDValue NarrowLoad = reduceLoadWidth(N))
     return NarrowLoad;
 
   // Here is a common situation. We want to optimize:
@@ -11357,7 +11357,7 @@
   if (N0.getOpcode() == ISD::TRUNCATE) {
     // fold (sext (truncate (load x))) -> (sext (smaller load x))
     // fold (sext (truncate (srl (load x), c))) -> (sext (smaller load (x+c/n)))
-    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
+    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
       SDNode *oye = N0.getOperand(0).getNode();
       if (NarrowLoad.getNode() != N0.getNode()) {
         CombineTo(N0.getNode(), NarrowLoad);
@@ -11621,7 +11621,7 @@
   if (N0.getOpcode() == ISD::TRUNCATE) {
     // fold (zext (truncate (load x))) -> (zext (smaller load x))
     // fold (zext (truncate (srl (load x), c))) -> (zext (smaller load (x+c/n)))
-    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
+    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
       SDNode *oye = N0.getOperand(0).getNode();
       if (NarrowLoad.getNode() != N0.getNode()) {
         CombineTo(N0.getNode(), NarrowLoad);
@@ -11864,7 +11864,7 @@
   // fold (aext (truncate (load x))) -> (aext (smaller load x))
   // fold (aext (truncate (srl (load x), c))) -> (aext (small load (x+c/n)))
   if (N0.getOpcode() == ISD::TRUNCATE) {
-    if (SDValue NarrowLoad = ReduceLoadWidth(N0.getNode())) {
+    if (SDValue NarrowLoad = reduceLoadWidth(N0.getNode())) {
       SDNode *oye = N0.getOperand(0).getNode();
       if (NarrowLoad.getNode() != N0.getNode()) {
         CombineTo(N0.getNode(), NarrowLoad);
@@ -12095,13 +12095,10 @@
   return SDValue();
 }
 
-/// If the result of a wider load is shifted to right of N bits and then
-/// truncated to a narrower type and where N is a multiple of number of bits of
-/// the narrower type, transform it to a narrower load from address + N / num of
-/// bits of new type. Also narrow the load if the result is masked with an AND
-/// to effectively produce a smaller type. If the result is to be extended, also
-/// fold the extension to form a extending load.
-SDValue DAGCombiner::ReduceLoadWidth(SDNode *N) {
+/// If the result of a load is shifted/masked/truncated to an effectively
+/// narrower type, try to transform the load to a narrower type and/or
+/// use an extending load.
+SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
   unsigned Opc = N->getOpcode();
 
   ISD::LoadExtType ExtType = ISD::NON_EXTLOAD;
@@ -12113,7 +12110,14 @@
   if (VT.isVector())
     return SDValue();
 
+  // The ShAmt variable is used to indicate that we've consumed a right shift,
+  // i.e. we want to narrow the width of the load by not loading the ShAmt
+  // least significant bits.
   unsigned ShAmt = 0;
+  // A special case is when the least significant bits from the load are masked
+  // away using an AND rather than a right shift. HasShiftedOffset is used to
+  // indicate that the narrowed load should be left-shifted ShAmt bits to get
+  // the result.
   bool HasShiftedOffset = false;
   // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
   // extended to VT.
@@ -12122,23 +12126,29 @@
     ExtVT = cast<VTSDNode>(N->getOperand(1))->getVT();
   } else if (Opc == ISD::SRL) {
     // Another special-case: SRL is basically zero-extending a narrower value,
-    // or it maybe shifting a higher subword, half or byte into the lowest
+    // or it may be shifting a higher subword, half or byte into the lowest
    // bits.
-    ExtType = ISD::ZEXTLOAD;
-    N0 = SDValue(N, 0);
-    auto *LN0 = dyn_cast<LoadSDNode>(N0.getOperand(0));
-    auto *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1));
-    if (!N01 || !LN0)
+    // Only handle shifts with a constant shift amount, and the shiftee must be
+    // a load.
+    auto *LN = dyn_cast<LoadSDNode>(N0);
+    auto *N1C = dyn_cast<ConstantSDNode>(N->getOperand(1));
+    if (!N1C || !LN)
+      return SDValue();
+    // If the shift amount is larger than the memory type then we're not
+    // accessing any of the loaded bytes.
+    ShAmt = N1C->getZExtValue();
+    uint64_t MemoryWidth = LN->getMemoryVT().getScalarSizeInBits();
+    if (MemoryWidth <= ShAmt)
+      return SDValue();
+    // Attempt to fold away the SRL by using ZEXTLOAD.
+    ExtType = ISD::ZEXTLOAD;
+    ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShAmt);
+    // If the original load is a SEXTLOAD then we can't simply replace it by a
+    // ZEXTLOAD (we could potentially replace it by a narrower SEXTLOAD
+    // followed by a ZEXT, but that is not handled at the moment).
+    if (LN->getExtensionType() == ISD::SEXTLOAD)
       return SDValue();
-
-    uint64_t ShiftAmt = N01->getZExtValue();
-    uint64_t MemoryWidth = LN0->getMemoryVT().getScalarSizeInBits();
-    if (LN0->getExtensionType() != ISD::SEXTLOAD && MemoryWidth > ShiftAmt)
-      ExtVT = EVT::getIntegerVT(*DAG.getContext(), MemoryWidth - ShiftAmt);
-    else
-      ExtVT = EVT::getIntegerVT(*DAG.getContext(),
-                                VT.getScalarSizeInBits() - ShiftAmt);
   } else if (Opc == ISD::AND) {
     // An AND with a constant mask is the same as a truncate + zero-extend.
     auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
@@ -12161,55 +12171,73 @@
     ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
   }
 
-  if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
-    SDValue SRL = N0;
-    if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) {
-      ShAmt = ConstShift->getZExtValue();
-      unsigned EVTBits = ExtVT.getScalarSizeInBits();
-      // Is the shift amount a multiple of size of VT?
-      if ((ShAmt & (EVTBits-1)) == 0) {
-        N0 = N0.getOperand(0);
-        // Is the load width a multiple of size of VT?
-        if ((N0.getScalarValueSizeInBits() & (EVTBits - 1)) != 0)
-          return SDValue();
-      }
+  // In case Opc==SRL we've already prepared ExtVT/ExtType/ShAmt based on doing
+  // a right shift. Here we redo some of those checks, to possibly adjust the
+  // ExtVT even further based on "a masking AND". We could also end up here for
+  // other reasons (e.g. based on Opc==TRUNCATE) and that is why some checks
+  // need to be done here as well.
+  if (Opc == ISD::SRL || N0.getOpcode() == ISD::SRL) {
+    SDValue SRL = Opc == ISD::SRL ? SDValue(N, 0) : N0;
+    // Bail out when the SRL has more than one use. This is done for historical
+    // (undocumented) reasons. Maybe the intent was to guard the AND-masking
+    // check below? And maybe it could be unprofitable to do the transform in
+    // case the SRL has multiple uses and we get here with Opc!=ISD::SRL?
+    // FIXME: Can't we just skip this check for the Opc==ISD::SRL case?
+    if (!SRL.hasOneUse())
+      return SDValue();
+
+    // Only handle shifts with a constant shift amount, and the shiftee must be
+    // a load.
+    auto *LN = dyn_cast<LoadSDNode>(SRL.getOperand(0));
+    auto *SRL1C = dyn_cast<ConstantSDNode>(SRL.getOperand(1));
+    if (!SRL1C || !LN)
+      return SDValue();
 
-      // At this point, we must have a load or else we can't do the transform.
-      auto *LN0 = dyn_cast<LoadSDNode>(N0);
-      if (!LN0) return SDValue();
+    // If the shift amount is larger than the input type then we're not
+    // accessing any of the loaded bytes. If the load was a zextload/extload
+    // then the result of the shift+trunc is zero/undef (handled elsewhere).
+    ShAmt = SRL1C->getZExtValue();
+    if (ShAmt >= LN->getMemoryVT().getSizeInBits())
+      return SDValue();
 
-      // Because a SRL must be assumed to *need* to zero-extend the high bits
-      // (as opposed to anyext the high bits), we can't combine the zextload
-      // lowering of SRL and an sextload.
-      if (LN0->getExtensionType() == ISD::SEXTLOAD)
-        return SDValue();
+    // Because a SRL must be assumed to *need* to zero-extend the high bits
+    // (as opposed to anyext the high bits), we can't combine the zextload
+    // lowering of SRL and an sextload.
+    if (LN->getExtensionType() == ISD::SEXTLOAD)
+      return SDValue();
 
-      // If the shift amount is larger than the input type then we're not
-      // accessing any of the loaded bytes. If the load was a zextload/extload
-      // then the result of the shift+trunc is zero/undef (handled elsewhere).
-      if (ShAmt >= LN0->getMemoryVT().getSizeInBits())
-        return SDValue();
+    unsigned ExtVTBits = ExtVT.getScalarSizeInBits();
+    // Is the shift amount a multiple of size of ExtVT?
+    if ((ShAmt & (ExtVTBits - 1)) != 0)
+      return SDValue();
+    // Is the load width a multiple of size of ExtVT?
+    if ((SRL.getScalarValueSizeInBits() & (ExtVTBits - 1)) != 0)
+      return SDValue();
 
-      // If the SRL is only used by a masking AND, we may be able to adjust
-      // the ExtVT to make the AND redundant.
-      SDNode *Mask = *(SRL->use_begin());
-      if (Mask->getOpcode() == ISD::AND &&
-          isa<ConstantSDNode>(Mask->getOperand(1))) {
-        const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
-        if (ShiftMask.isMask()) {
-          EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
-                                           ShiftMask.countTrailingOnes());
-          // If the mask is smaller, recompute the type.
-          if ((ExtVT.getScalarSizeInBits() > MaskedVT.getScalarSizeInBits()) &&
-              TLI.isLoadExtLegal(ExtType, N0.getValueType(), MaskedVT))
-            ExtVT = MaskedVT;
-        }
+    // If the SRL is only used by a masking AND, we may be able to adjust
+    // the ExtVT to make the AND redundant.
+    SDNode *Mask = *(SRL->use_begin());
+    if (SRL.hasOneUse() && Mask->getOpcode() == ISD::AND &&
+        isa<ConstantSDNode>(Mask->getOperand(1))) {
+      const APInt& ShiftMask = Mask->getConstantOperandAPInt(1);
+      if (ShiftMask.isMask()) {
+        EVT MaskedVT = EVT::getIntegerVT(*DAG.getContext(),
+                                         ShiftMask.countTrailingOnes());
+        // If the mask is smaller, recompute the type.
+        if ((ExtVTBits > MaskedVT.getScalarSizeInBits()) &&
+            TLI.isLoadExtLegal(ExtType, SRL.getValueType(), MaskedVT))
+          ExtVT = MaskedVT;
       }
     }
+
+    N0 = SRL.getOperand(0);
   }
 
-  // If the load is shifted left (and the result isn't shifted back right),
-  // we can fold the truncate through the shift.
+  // If the load is shifted left (and the result isn't shifted back right), we
+  // can fold a truncate through the shift. The typical scenario is that N
+  // points at a TRUNCATE here so the attempted fold is:
+  //   (truncate (shl (load x), c)) -> (shl (narrow load x), c)
+  // ShLeftAmt will indicate how much a narrowed load should be shifted left.
   unsigned ShLeftAmt = 0;
   if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
       ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
@@ -12237,12 +12265,12 @@
     return LVTStoreBits - EVTStoreBits - ShAmt;
   };
 
-  // For big endian targets, we need to adjust the offset to the pointer to
-  // load the correct bytes.
-  if (DAG.getDataLayout().isBigEndian())
-    ShAmt = AdjustBigEndianShift(ShAmt);
+  // We need to adjust the pointer to the load by ShAmt bits in order to load
+  // the correct bytes.
+  unsigned PtrAdjustmentInBits =
+      DAG.getDataLayout().isBigEndian() ? AdjustBigEndianShift(ShAmt) : ShAmt;
 
-  uint64_t PtrOff = ShAmt / 8;
+  uint64_t PtrOff = PtrAdjustmentInBits / 8;
   Align NewAlign = commonAlignment(LN0->getAlign(), PtrOff);
   SDLoc DL(LN0);
   // The original load itself didn't wrap, so an offset within it doesn't.
@@ -12285,11 +12313,6 @@
   }
 
   if (HasShiftedOffset) {
-    // Recalculate the shift amount after it has been altered to calculate
-    // the offset.
-    if (DAG.getDataLayout().isBigEndian())
-      ShAmt = AdjustBigEndianShift(ShAmt);
-
     // We're using a shifted mask, so the load now has an offset. This means
     // that data has been loaded into the lower bytes than it would have been
    // before, so we need to shl the loaded data into the correct position in the
@@ -12382,7 +12405,7 @@
 
   // fold (sext_in_reg (load x)) -> (smaller sextload x)
   // fold (sext_in_reg (srl (load x), c)) -> (smaller sextload (x+c/evtbits))
-  if (SDValue NarrowLoad = ReduceLoadWidth(N))
+  if (SDValue NarrowLoad = reduceLoadWidth(N))
     return NarrowLoad;
 
   // fold (sext_in_reg (srl X, 24), i8) -> (sra X, 24)
@@ -12669,7 +12692,7 @@
   // fold (truncate (load x)) -> (smaller load x)
   // fold (truncate (srl (load x), c)) -> (smaller load (x+c/evtbits))
   if (!LegalTypes || TLI.isTypeDesirableForOp(N0.getOpcode(), VT)) {
-    if (SDValue Reduced = ReduceLoadWidth(N))
+    if (SDValue Reduced = reduceLoadWidth(N))
       return Reduced;
 
     // Handle the case where the load remains an extending load even
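
Note (illustrative, not part of the patch): the pointer adjustment above, PtrOff = PtrAdjustmentInBits / 8 with AdjustBigEndianShift applied first on big-endian targets, relies on the fact that right-shifting a wide loaded value by ShAmt bits and truncating yields the same bits as a narrower load at a byte offset derived from ShAmt. The standalone C++ sketch below demonstrates that equivalence with plain memory accesses rather than SelectionDAG nodes; the names Wide, ViaShift and ViaNarrowLoad are invented for the example.

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  // Wide value and its in-memory representation, standing in for the memory
  // that a 32-bit load would read.
  uint32_t Wide = 0xAABBCCDDu;
  unsigned char Bytes[sizeof(Wide)];
  std::memcpy(Bytes, &Wide, sizeof(Wide));

  // "(srl (load i32 p), 16)" keeps only the 16 most significant bits.
  uint16_t ViaShift = static_cast<uint16_t>(Wide >> 16);

  // The narrowed form loads only 16 bits, at a bit offset of ShAmt on
  // little-endian targets and LVTStoreBits - EVTStoreBits - ShAmt (here 0)
  // on big-endian targets, mirroring AdjustBigEndianShift in the patch.
  const unsigned ShAmt = 16;     // bits shifted out of the wide value
  const unsigned WideBits = 32;  // width of the original load
  const unsigned NarrowBits = 16;  // width of the narrowed load
  uint16_t Probe = 1;
  bool IsLittleEndian = *reinterpret_cast<unsigned char *>(&Probe) == 1;
  unsigned PtrAdjustmentInBits =
      IsLittleEndian ? ShAmt : WideBits - NarrowBits - ShAmt;
  uint16_t ViaNarrowLoad;
  std::memcpy(&ViaNarrowLoad, Bytes + PtrAdjustmentInBits / 8,
              sizeof(ViaNarrowLoad));

  // Both ways of computing the upper half agree, which is what allows the
  // combine to replace the load+shift with a narrower zero-extending load.
  assert(ViaShift == ViaNarrowLoad);
  return 0;
}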