Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -447,6 +447,7 @@
     SDNode *MatchRotate(SDValue LHS, SDValue RHS, const SDLoc &DL);
     SDValue MatchLoadCombine(SDNode *N);
     SDValue ReduceLoadWidth(SDNode *N);
+    SDValue foldRedundantShiftedMasks(SDNode *N);
     SDValue ReduceLoadOpStoreWidth(SDNode *N);
     SDValue splitMergedValStore(StoreSDNode *ST);
     SDValue TransformFPLoadStorePair(SDNode *N);
@@ -4068,6 +4069,109 @@
   return false;
 }
 
+// Fold two masked uses of the same value, such as:
+//   x1 = (and x, 0x00FF)
+//   x2 = (and (shl x, 8), 0xFF00)
+// into:
+//   x2 = (shl x1, 8) ; reuse the computation of x1
+SDValue DAGCombiner::foldRedundantShiftedMasks(SDNode *AND) {
+  const SDValue &SHIFT = AND->getOperand(0);
+  if ((SHIFT.getNumOperands() != 2) || (!SHIFT.hasOneUse()))
+    return SDValue();
+
+  const ConstantSDNode *ShiftAmount =
+      dyn_cast<ConstantSDNode>(SHIFT.getOperand(1));
+  if (!ShiftAmount)
+    return SDValue();
+
+  const ConstantSDNode *Mask = dyn_cast<ConstantSDNode>(AND->getOperand(1));
+  if (!Mask)
+    return SDValue();
+
+  SDValue MASKED = SHIFT.getOperand(0);
+  SDNode *MaskedValue = MASKED.getNode();
+  unsigned N0Opcode = SHIFT.getOpcode();
+  for (SDNode *OtherUser : MaskedValue->uses()) {
+    if ((OtherUser == SHIFT.getNode()) ||
+        (OtherUser->getOpcode() != ISD::AND))
+      continue;
+
+    ConstantSDNode *OtherMask =
+        dyn_cast<ConstantSDNode>(OtherUser->getOperand(1));
+
+    if (!OtherMask)
+      continue;
+
+    bool CanReduce = false;
+
+    const APInt &MaskValue = Mask->getAPIntValue();
+    const APInt &ShiftValue = ShiftAmount->getAPIntValue();
+    const APInt &OtherMaskValue = OtherMask->getAPIntValue();
+
+    KnownBits MaskedValueBits;
+    DAG.computeKnownBits(MASKED, MaskedValueBits);
+    KnownBits ShiftedValueBits;
+    DAG.computeKnownBits(SHIFT, ShiftedValueBits);
+
+    const APInt EffectiveOtherMask = OtherMaskValue & ~MaskedValueBits.Zero;
+    const APInt EffectiveMask = MaskValue & ~ShiftedValueBits.Zero;
+
+    LLVM_DEBUG(
+        dbgs() << "\tMasked value: "; MASKED.dump();
+        dbgs() << "\t\tMasked value zero bits: 0x"
+               << MaskedValueBits.Zero.toString(16, false)
+               << "\n\n\t\tApplied mask: 0x"
+               << OtherMaskValue.toString(16, false) << " : ";
+        OtherUser->dump();
+        dbgs() << "\t\tEffective mask: 0x"
+               << EffectiveOtherMask.toString(16, false)
+               << "\n\n\tShifted by: " << ShiftValue.getZExtValue() << " : ";
+        SHIFT.dump();
+        dbgs() << "\t\tAnd masked by: 0x" << MaskValue.toString(16, false)
+               << " : ";
+        AND->dump();
+        dbgs() << "\t\tEffective mask to shifted value: 0x"
+               << EffectiveMask.toString(16, false) << '\n');
+
+    switch (N0Opcode) {
+    case ISD::SHL:
+      CanReduce = (EffectiveOtherMask.shl(ShiftValue) == EffectiveMask) ||
+                  (EffectiveMask.lshr(ShiftValue) == EffectiveOtherMask);
+      break;
+    case ISD::SRA:
+      if (!MaskedValueBits.Zero.isSignBitSet()) {
+        CanReduce = (EffectiveOtherMask.ashr(ShiftValue) == EffectiveMask);
+        break;
+      } else // The sign bit is known zero, so SRA behaves like SRL.
+        N0Opcode = ISD::SRL;
+      LLVM_FALLTHROUGH;
+    case ISD::SRL:
+      CanReduce = (EffectiveOtherMask.lshr(ShiftValue) == EffectiveMask) ||
+                  (EffectiveMask.shl(ShiftValue) == EffectiveOtherMask);
+      break;
+    case ISD::ROTL:
+      CanReduce = (EffectiveOtherMask.rotl(ShiftValue) == EffectiveMask);
+      break;
+    case ISD::ROTR:
+      CanReduce = (EffectiveOtherMask.rotr(ShiftValue) == EffectiveMask);
+      break;
+    // TODO: Add SHL_PARTS, SRA_PARTS and SRL_PARTS.
+    default:
+      return SDValue();
+    }
+    if (CanReduce) {
+      LLVM_DEBUG(dbgs() << "\tCan replace it\n");
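+      // Illustration of the rewrite built below (the constants are an
+      // example, not taken from the patch): with x1 = (and x, 0x00FF) as
+      // OtherUser and this node being x2 = (and (shl x, 8), 0xFF00), the
+      // replacement is x2' = (shl x1, 8), reusing x1 instead of masking x
+      // a second time.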
+      SDValue ShiftTheAND(OtherUser, 0);
+      const SDLoc DL(SHIFT);
+      EVT VT = AND->getValueType(0);
+      SDValue NewShift =
+          DAG.getNode(N0Opcode, DL, VT, ShiftTheAND, SHIFT.getOperand(1));
+      AddToWorklist(OtherUser);
+      return NewShift;
+    }
+  }
+  return SDValue();
+}
+
 SDValue DAGCombiner::visitAND(SDNode *N) {
   SDValue N0 = N->getOperand(0);
   SDValue N1 = N->getOperand(1);
@@ -4268,6 +4372,9 @@
        (N0.getOpcode() == ISD::ANY_EXTEND &&
         N0.getOperand(0).getOpcode() == ISD::LOAD))) {
     if (SDValue Res = ReduceLoadWidth(N)) {
+      // ReduceLoadWidth may now return a SHL of a narrowed load; a SHL node
+      // has no chain result to CombineTo with, so return it directly.
+      if (Res.getOpcode() == ISD::SHL)
+        return Res;
+
       LoadSDNode *LN0 = N0->getOpcode() == ISD::ANY_EXTEND
                             ? cast<LoadSDNode>(N0.getOperand(0))
                             : cast<LoadSDNode>(N0);
@@ -4277,6 +4384,9 @@
     }
   }
 
+  if (SDValue r = foldRedundantShiftedMasks(N))
+    return r;
+
   if (Level >= AfterLegalizeTypes) {
     // Attempt to propagate the AND back up to the leaves which, if they're
     // loads, can be combined to narrow loads and the AND node can be removed.
@@ -6332,13 +6442,36 @@
   }
 
   // fold (srl (shl x, c), c) -> (and x, cst2)
-  if (N0.getOpcode() == ISD::SHL && N0.getOperand(1) == N1 &&
-      isConstantOrConstantVector(N1, /* NoOpaques */ true)) {
-    SDLoc DL(N);
-    SDValue Mask =
-        DAG.getNode(ISD::SRL, DL, VT, DAG.getAllOnesConstant(DL, VT), N1);
-    AddToWorklist(Mask.getNode());
-    return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), Mask);
+  if ((N0.getOpcode() == ISD::SHL) &&
+      (isConstantOrConstantVector(N1, /* NoOpaques */ true))) {
+    bool CanFold = N0.getOperand(1) == N1;
+    if (!CanFold) {
+      // The shift amounts may still be equal constants, just not the same
+      // node.
+      const ConstantSDNode *CN0N1 =
+          dyn_cast<ConstantSDNode>(N0.getOperand(1));
+      if (CN0N1 && N1C)
+        CanFold = CN0N1->getZExtValue() == N1C->getZExtValue();
+    }
+
+    if (CanFold) {
+      // fold (srl (shl X, c), c) -> X if the upper c bits of X are known to
+      // be 0
+      // TODO: Add more instructions that produce known zero upper bits,
+      // other than zext loads.
+      if (N1C) {
+        if (LoadSDNode *X = dyn_cast<LoadSDNode>(N0.getOperand(0))) {
+          const unsigned XSize = X->getValueSizeInBits(0);
+          const unsigned XMemSize = X->getMemOperand()->getSize() * 8;
+          if ((XSize > XMemSize) &&
+              ((XSize - XMemSize) >= N1C->getZExtValue()) &&
+              (X->getExtensionType() == ISD::LoadExtType::ZEXTLOAD))
+            return N0.getOperand(0);
+        }
+      }
+      SDLoc DL(N);
+      SDValue Mask =
+          DAG.getNode(ISD::SRL, DL, VT, DAG.getAllOnesConstant(DL, VT), N1);
+      AddToWorklist(Mask.getNode());
+      return DAG.getNode(ISD::AND, DL, VT, N0.getOperand(0), Mask);
+    }
   }
 
   // fold (srl (anyextend x), c) -> (and (anyextend (srl x, c)), mask)
@@ -8540,6 +8673,9 @@
   if (VT.isVector())
     return SDValue();
 
+  unsigned ShAmt = 0;
+  unsigned ShLeftAmt = 0;
+
   // Special case: SIGN_EXTEND_INREG is basically truncating to ExtVT then
   // extended to VT.
   if (Opc == ISD::SIGN_EXTEND_INREG) {
@@ -8567,15 +8703,65 @@
   } else if (Opc == ISD::AND) {
     // An AND with a constant mask is the same as a truncate + zero-extend.
     auto AndC = dyn_cast<ConstantSDNode>(N->getOperand(1));
-    if (!AndC || !AndC->getAPIntValue().isMask())
+    if (!AndC)
       return SDValue();
 
-    unsigned ActiveBits = AndC->getAPIntValue().countTrailingOnes();
+    // TODO: Not only [shifted] masks should be accepted.
+    // (and ld.16 [M], 0x00AB) can be replaced by (and ld.8.zext16 [M], 0x00AB).
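+    // Worked example for the shifted-mask path below (illustrative,
+    // little-endian): for (and (ld.32 [M]), 0xFF00), MinBit = 8 and
+    // MaxBit = 16, so only the byte at [M + 1] is read and the node can be
+    // rebuilt as (shl (ld.8.zext32 [M + 1]), 8).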
+    const APInt &MaskAPInt = AndC->getAPIntValue();
+    if (!(MaskAPInt.isMask() || MaskAPInt.isShiftedMask()))
+      return SDValue();
+
+    unsigned MaxBit = MaskAPInt.getBitWidth() - MaskAPInt.countLeadingZeros();
+    const unsigned MinBit = MaskAPInt.countTrailingZeros();
+    // Only accept masks whose boundaries are multiples of 8 bits and whose
+    // widths are powers of 2.
+    if (!MaxBit || (0 != ((MaxBit | MinBit) % 8)))
+      return SDValue();
+
+    unsigned ActiveBits = MaxBit - MinBit;
+    if (ActiveBits & (ActiveBits - 1))
+      return SDValue();
+
+    LLVM_DEBUG(dbgs() << "\tMask: 0x" << MaskAPInt.toString(16, false)
+                      << " : ";
+               AndC->dump();
+               dbgs() << "\t\tmaxActiveBit: " << MaxBit - 1
+                      << "\n\t\tminActiveBit: " << MinBit << '\n');
+
+    LoadSDNode *LN0 = dyn_cast<LoadSDNode>(N0);
     ExtType = ISD::ZEXTLOAD;
+    if (MinBit != 0) {
+      // TODO: How should this be treated if it is not a load?
+      if (LN0 == nullptr)
+        return SDValue();
+
+      const auto &MemVT = LN0->getMemoryVT();
+      if (MinBit >= MemVT.getSizeInBits()) {
+        // The (and) is filtering what was extended, not the actual data
+        // value.
+        if (ISD::LoadExtType::ZEXTLOAD == LN0->getExtensionType()) {
+          // We only read the extension's zero bits.
+          return DAG.getConstant(0, SDLoc(N), AndC->getValueType(0));
+        }
+        // We access the sign extension, which is not known here.
+        return SDValue();
+      }
+      if (MaxBit > MemVT.getSizeInBits())
+        ExtType = LN0->getExtensionType();
+    }
+    // TODO: Accept SEXT if the architecture can do a LD?SH (load + shl).
+    // An (and (ld.32bit.sext.from16 [M]), 0x00FFFF00) can be replaced by
+    // (and (shl (ld.32bit.sext.from8 [M+1]), 8), 0x00FFFF00).
+    if (ExtType != ISD::ZEXTLOAD)
+      return SDValue();
+
     ExtVT = EVT::getIntegerVT(*DAG.getContext(), ActiveBits);
+    ShAmt = MinBit;
+    ShLeftAmt = MinBit;
+    LLVM_DEBUG(dbgs() << "\tCan replace load: "; N0.dump();
+               dbgs() << "\tBy a load of width " << ActiveBits
+                      << " and with offset of " << ShAmt / 8 << '\n');
   }
 
-  unsigned ShAmt = 0;
   if (N0.getOpcode() == ISD::SRL && N0.hasOneUse()) {
     SDValue SRL = N0;
     if (auto *ConstShift = dyn_cast<ConstantSDNode>(SRL.getOperand(1))) {
@@ -8626,7 +8812,6 @@
 
   // If the load is shifted left (and the result isn't shifted back right),
   // we can fold the truncate through the shift.
-  unsigned ShLeftAmt = 0;
   if (ShAmt == 0 && N0.getOpcode() == ISD::SHL && N0.hasOneUse() &&
       ExtVT == VT && TLI.isNarrowingProfitable(N0.getValueType(), VT)) {
    if (ConstantSDNode *N01 = dyn_cast<ConstantSDNode>(N0.getOperand(1))) {
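
A minimal source-level sketch of the redundant-mask pattern that the new
foldRedundantShiftedMasks() combine targets (the function below is
illustrative and not part of the patch):

  unsigned duplicateLowByte(unsigned x) {
    unsigned lo = x & 0x00FF;          // x1 = (and x, 0x00FF)
    unsigned hi = (x << 8) & 0xFF00;   // x2 = (and (shl x, 8), 0xFF00)
    // With the combine, hi becomes (shl x1, 8), so x is masked only once.
    return lo | hi;
  }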