Index: lib/CodeGen/SelectionDAG/DAGCombiner.cpp
===================================================================
--- lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -448,6 +448,7 @@
     SDValue MatchLoadCombine(SDNode *N);
     SDValue ReduceLoadWidth(SDNode *N);
     SDValue foldRedundantShiftedMasks(SDNode *N);
+    SDValue checkOverwrittenLoadReduceStoreWidth(StoreSDNode *Store);
     SDValue ReduceLoadOpStoreWidth(SDNode *N);
     SDValue splitMergedValStore(StoreSDNode *ST);
     SDValue TransformFPLoadStorePair(SDNode *N);
@@ -12972,9 +12973,12 @@
 /// Check to see if V is (and load (ptr), imm), where the load is having
 /// specific bytes cleared out.  If so, return the byte size being masked out
 /// and the shift amount.
-static std::pair<unsigned, unsigned>
+struct MaskedLoadValue {
+  unsigned Bytes = 0, Offset = 0, Shl = 0;
+};
+static MaskedLoadValue
 CheckForMaskedLoad(SDValue V, SDValue Ptr, SDValue Chain) {
-  std::pair<unsigned, unsigned> Result(0, 0);
+  MaskedLoadValue Result;
 
   // Check for the structure we're looking for.
   if (V->getOpcode() != ISD::AND ||
@@ -13038,8 +13042,9 @@
   // is aligned the same as the access width.
   if (NotMaskTZ && NotMaskTZ/8 % MaskedBytes) return Result;
 
-  Result.first = MaskedBytes;
-  Result.second = NotMaskTZ/8;
+  Result.Bytes = MaskedBytes;
+  Result.Offset = NotMaskTZ/8;
+  Result.Shl = Result.Offset;
   return Result;
 }
 
@@ -13047,32 +13052,30 @@
 /// MaskInfo. If so, replace the specified store with a narrower store of
 /// truncated IVal.
 static SDNode *
-ShrinkLoadReplaceStoreWithStore(const std::pair<unsigned, unsigned> &MaskInfo,
+ShrinkLoadReplaceStoreWithStore(const MaskedLoadValue &MaskInfo,
                                 SDValue IVal, StoreSDNode *St,
                                 DAGCombiner *DC) {
-  unsigned NumBytes = MaskInfo.first;
-  unsigned ByteShift = MaskInfo.second;
   SelectionDAG &DAG = DC->getDAG();
 
   // Check to see if IVal is all zeros in the part being masked in by the 'or'
   // that uses this.  If not, this is not a replacement.
   APInt Mask = ~APInt::getBitsSet(IVal.getValueSizeInBits(),
-                                  ByteShift*8, (ByteShift+NumBytes)*8);
+                                  MaskInfo.Shl*8, (MaskInfo.Shl+MaskInfo.Bytes)*8);
   if (!DAG.MaskedValueIsZero(IVal, Mask)) return nullptr;
 
   // Check that it is legal on the target to do this.  It is legal if the new
   // VT we're shrinking to (i8/i16/i32) is legal or we're still before type
   // legalization.
-  MVT VT = MVT::getIntegerVT(NumBytes*8);
+  MVT VT = MVT::getIntegerVT(MaskInfo.Bytes*8);
   if (!DC->isTypeLegal(VT))
     return nullptr;
 
   // Okay, we can do this!  Replace the 'St' store with a store of IVal that is
-  // shifted by ByteShift and truncated down to NumBytes.
-  if (ByteShift) {
+  // shifted by MaskInfo.Shl and truncated down to MaskInfo.Bytes.
+  if (MaskInfo.Shl) {
     SDLoc DL(IVal);
     IVal = DAG.getNode(ISD::SRL, DL, IVal.getValueType(), IVal,
-                       DAG.getConstant(ByteShift*8, DL,
+                       DAG.getConstant(MaskInfo.Shl*8, DL,
                                     DC->getShiftAmountTy(IVal.getValueType())));
   }
 
@@ -13081,9 +13084,9 @@
   unsigned NewAlign = St->getAlignment();
 
   if (DAG.getDataLayout().isLittleEndian())
-    StOffset = ByteShift;
+    StOffset = MaskInfo.Offset;
   else
-    StOffset = IVal.getValueType().getStoreSize() - ByteShift - NumBytes;
+    StOffset = IVal.getValueType().getStoreSize() - MaskInfo.Offset - MaskInfo.Bytes;
 
   SDValue Ptr = St->getBasePtr();
   if (StOffset) {
@@ -13103,6 +13106,126 @@
       .getNode();
 }
 
+/// Detects that half of a store operation is just overwriting a previously
+/// loaded value, using OR and SHL operations, such as:
+/// st i16 ((zext(ld i8, [M]) to i16) or (shl (zext(ld i8, [M]) to i16), 8)), [M]
+SDValue
+DAGCombiner::checkOverwrittenLoadReduceStoreWidth(StoreSDNode *ST) {
+  // It is safe to reduce the ST operation even after LegalizeTypes, as we
+  // only create a store of the same width as the load.
+  assert(ST && "Expecting a non-null StoreSDNode");
+
+  if (ST->isVolatile())
+    return SDValue();
+
+  SDValue OR = ST->getValue();
+  if (!OR.hasOneUse())
+    return SDValue();
+
+  if (OR.getOpcode() != ISD::OR)
+    return SDValue();
+
+  EVT VT = OR.getValueType();
+  if (VT.isVector())
+    return SDValue();
+
+  // One operand of the OR must be a load; the other must be that same load
+  // shifted left by a constant number of bytes.
+  SDValue LDVal = OR.getOperand(0);
+  LoadSDNode *LD = dyn_cast<LoadSDNode>(LDVal);
+  SDNode *SHL;
+  if (!LD) {
+    SHL = LDVal.getNode();
+    LDVal = OR.getOperand(1);
+    LD = dyn_cast<LoadSDNode>(LDVal);
+    if (!LD)
+      return SDValue();
+  } else
+    SHL = OR.getOperand(1).getNode();
+
+  if (SHL->getOpcode() != ISD::SHL || SHL->getOperand(0).getNode() != LD ||
+      !isa<ConstantSDNode>(SHL->getOperand(1)))
+    return SDValue();
+
+  // The store must be chained directly on the load.
+  SDValue Chain = ST->getChain();
+  if (Chain != SDValue(LD, 1))
+    return SDValue();
+
+  unsigned BytesShift = cast<ConstantSDNode>(SHL->getOperand(1).getNode())
+                            ->getAPIntValue().getSExtValue() / 8;
+  unsigned StoreMemSz = ST->getMemoryVT().getStoreSize();
+  if (2 * BytesShift != StoreMemSz)
+    return SDValue();
+
+  const SDValue LoadPtr = LD->getBasePtr();
+  SDValue Ptr = ST->getBasePtr();
+  unsigned LoadMemSz = LD->getMemoryVT().getStoreSize();
+
+  // TODO: Detect when the LOAD and STORE memory addresses are both ADD nodes
+  // of a common base address with a known constant difference.
+  // Ex: load i8 [M+3] and store i16 [M+2].
+  if (BytesShift && LoadPtr == Ptr) {
+    if (LoadMemSz < StoreMemSz && LD->getExtensionType() != ISD::ZEXTLOAD)
+      return SDValue();
+
+    if (LoadMemSz == BytesShift) {
+      // Replace something like
+      //   store i16 (or (zext (ld i8 [M]) to i16), (shl (ld i8 [M]), 8)), [M]
+      // with
+      //   store i8 (ld i8 [M]), [M+1]
+      LLVM_DEBUG(dbgs() << "\tGot load: "; LD->dump());
+      LLVM_DEBUG(dbgs() << "\tThat is shifted by: "; SHL->dump());
+      LLVM_DEBUG(dbgs() << "\tOR combined by: "; OR->dump());
+      LLVM_DEBUG(dbgs() << "\tAnd stored by: "; ST->dump());
+      LLVM_DEBUG(dbgs() << "\tNot writing the lower half of the value\n");
+      // For big-endian targets, ShrinkLoadReplaceStoreWithStore adjusts the
+      // pointer offset so that the correct bytes are stored.
+      MaskedLoadValue MLV = {LoadMemSz, LoadMemSz, 0};
+      return SDValue(ShrinkLoadReplaceStoreWithStore(MLV, LDVal, ST, this), 0);
+    }
+    return SDValue();
+  }
+
+  SDValue OffsetPtr = LoadPtr;
+  SDValue NonOffsetPtr = Ptr;
+  if (DAG.getDataLayout().isBigEndian())
+    std::swap(OffsetPtr, NonOffsetPtr);
+
+  if (OffsetPtr.getOpcode() != ISD::ADD ||
+      !(OffsetPtr.getOperand(0) == NonOffsetPtr ||
+        OffsetPtr.getOperand(1) == NonOffsetPtr))
+    return SDValue();
+
+  ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(OffsetPtr.getOperand(1));
+  if (!Offset)
+    Offset = dyn_cast<ConstantSDNode>(OffsetPtr.getOperand(0));
+  if (!Offset)
+    return SDValue();
+
+  unsigned LoadByteOffset = Offset->getAPIntValue().getZExtValue();
+  if (!(LoadByteOffset == LoadMemSz && 2 * LoadMemSz == StoreMemSz &&
+        LD->getExtensionType() == ISD::ZEXTLOAD))
+    return SDValue();
+
+  // Replace something like
+  //   store i16 (or (ld i8 [M+1]), (shl (ld i8 [M+1]), 8)), [M]
+  // with
+  //   store i8 (ld i8 [M+1]), [M]. The load must be zero-extending.
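+  // For the MaskedLoadValue built below: Bytes = LoadMemSz gives the width of
+  // the narrowed store, Offset = 0 stores it at the store's own base address
+  // on little-endian targets (ShrinkLoadReplaceStoreWithStore flips this for
+  // big-endian), and Shl = 0 means the zero-extended loaded value is stored
+  // as-is, without a pre-store shift.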
+  LLVM_DEBUG(dbgs() << "\tGot load: "; LD->dump());
+  LLVM_DEBUG(dbgs() << "\tThat is shifted by: "; SHL->dump());
+  LLVM_DEBUG(dbgs() << "\tOR combined by: "; OR->dump());
+  LLVM_DEBUG(dbgs() << "\tAnd stored by: "; ST->dump());
+  LLVM_DEBUG(dbgs() << "\tNot writing the upper half of the value\n");
+  MaskedLoadValue MLV = {LoadMemSz, 0, 0};
+  return SDValue(ShrinkLoadReplaceStoreWithStore(MLV, LDVal, ST, this), 0);
+}
+
 /// Look for sequence of load / op / store where op is one of 'or', 'xor', and
 /// 'and' of immediates.  If 'op' is only touching some of the loaded bits, try
 /// narrowing the load and store if it would end up being a win for performance
@@ -13128,16 +13251,16 @@
   // load + replace + store sequence with a single (narrower) store, which makes
   // the load dead.
   if (Opc == ISD::OR) {
-    std::pair<unsigned, unsigned> MaskedLoad;
+    MaskedLoadValue MaskedLoad;
     MaskedLoad = CheckForMaskedLoad(Value.getOperand(0), Ptr, Chain);
-    if (MaskedLoad.first)
+    if (MaskedLoad.Bytes)
       if (SDNode *NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
                                                   Value.getOperand(1), ST,this))
         return SDValue(NewST, 0);
 
     // Or is commutative, so try swapping X and Y.
     MaskedLoad = CheckForMaskedLoad(Value.getOperand(1), Ptr, Chain);
-    if (MaskedLoad.first)
+    if (MaskedLoad.Bytes)
       if (SDNode *NewST = ShrinkLoadReplaceStoreWithStore(MaskedLoad,
                                                   Value.getOperand(0), ST,this))
         return SDValue(NewST, 0);
@@ -14519,6 +14642,9 @@
   if (SDValue NewSt = splitMergedValStore(ST))
     return NewSt;
 
+  if (SDValue NewSt = checkOverwrittenLoadReduceStoreWidth(ST))
+    return NewSt;
+
   return ReduceLoadOpStoreWidth(N);
 }
 
Index: test/CodeGen/ARM/2018_05_30_FoldMaskedMoves.ll
===================================================================
--- /dev/null
+++ test/CodeGen/ARM/2018_05_30_FoldMaskedMoves.ll
@@ -0,0 +1,121 @@
+; RUN: llc -O3 -march=arm %s -o - | FileCheck %s -check-prefix=ARM
+; RUN: llc -O3 -march=armeb %s -o - | FileCheck %s -check-prefix=ARMEB
+
+target datalayout = "e-m:e-p:32:32-i64:64-v128:64:128-a:0:32-n32-S64"
+target triple = "armv4t-arm-none-eabi"
+
+define void @foo1(i16* %b) {
+entry:
+  %0 = load i16, i16* %b, align 2
+  %conv = sext i16 %0 to i32
+  %and = and i32 %conv, 65280
+  %1 = lshr i32 %conv, 8
+  %and3 = and i32 %1, 255
+  %or = or i32 %and3, %and
+  %conv4 = trunc i32 %or to i16
+  store i16 %conv4, i16* %b, align 2
+  ret void
+}
+
+define void @foo2(i32* %b) {
+entry:
+  %0 = load i32, i32* %b, align 4
+  %1 = lshr i32 %0, 16
+  %and2 = and i32 %1, 65535
+  %and = and i32 %0, 4294901760
+  %or = or i32 %and2, %and
+  store i32 %or, i32* %b, align 4
+  ret void
+}
+
+define void @test_1x4p1(i32* %M, i32 %I) {
+entry:
+  %0 = getelementptr inbounds i32, i32* %M, i32 %I
+  %1 = load i32, i32* %0, align 4
+  %2 = and i32 %1, 65280
+  %3 = lshr i32 %1, 8
+  %4 = and i32 %3, 255
+  %5 = or i32 %2, %4
+  store i32 %5, i32* %0, align 4
+  ret void
+}
+
+define void @test_1x4p1_shl(i32* %M, i32 %I) {
+entry:
+  %0 = getelementptr inbounds i32, i32* %M, i32 %I
+  %1 = load i32, i32* %0, align 4
+  %2 = and i32 %1, 65280
+  %3 = shl i32 %1, 8
+  %4 = and i32 %3, 16711680
+  %5 = or i32 %2, %4
+  store i32 %5, i32* %0, align 4
+  ret void
+}
+
+; ARMEB-LABEL: foo1:
+; ARMEB: ldrb [[R1:r[0-9]+]], {{\[}}[[R0:r[0-9]+]]{{\]}}
+; ARMEB: strb [[R1:r[0-9]+]], {{\[}}[[R0:r[0-9]+]], #2]
+; ARMEB-LABEL: foo2:
+; ARMEB: ldrh [[R1:r[0-9]+]], {{\[}}[[R0:r[0-9]+]]{{\]}}
+; ARMEB: strh [[R1:r[0-9]+]], {{\[}}[[R0:r[0-9]+]], #2]
+; ARMEB-LABEL: test_1x4p1:
+; ARMEB: add [[R2:r[0-9]+]], [[R0:r[0-9]+]], [[R1:r[0-9]+]], lsl #2
+; ARMEB: ldrb [[R2:r[0-9]+]], {{\[}}[[R2:r[0-9]+]], #2]
+; ARMEB: orr [[R2:r[0-9]+]], [[R2:r[0-9]+]], [[R2:r[0-9]+]], lsl #8
+; ARMEB: str [[R2:r[0-9]+]], {{\[}}[[R0:r[0-9]+]], [[R1:r[0-9]+]], lsl #2]
+; ARMEB-LABEL: test_1x4p1_shl:
+; ARMEB: add [[R2:r[0-9]+]], [[R0:r[0-9]+]], [[R1:r[0-9]+]], lsl #2
+; ARMEB: ldrb [[R2:r[0-9]+]], {{\[}}[[R2:r[0-9]+]], #2]
+; ARMEB: lsl [[R3:r[0-9]+]], [[R2:r[0-9]+]], #8
+; ARMEB: orr [[R2:r[0-9]+]], [[R3:r[0-9]+]], [[R2:r[0-9]+]], lsl #16
+; ARMEB: str [[R2:r[0-9]+]], {{\[}}[[R0:r[0-9]+]], [[R1:r[0-9]+]], lsl #2]
+
+; ARM-LABEL: foo1:
+; ARM: ldrb [[R1:r[0-9]+]], {{\[}}[[R0:r[0-9]+]], #1]
+; ARM: strb [[R1:r[0-9]+]], {{\[}}[[R0:r[0-9]+]]{{\]}}
+; ARM-LABEL: foo2:
+; ARM: ldrh [[R1:r[0-9]+]], {{\[}}[[R0:r[0-9]+]], #2]
+; ARM: strh [[R1:r[0-9]+]], {{\[}}[[R0:r[0-9]+]]{{\]}}
+; ARM-LABEL: test_1x4p1:
+; ARM: add [[R2:r[0-9]+]], [[R0:r[0-9]+]], [[R1:r[0-9]+]], lsl #2
+; ARM: ldrb [[R2:r[0-9]+]], {{\[}}[[R2:r[0-9]+]], #1]
+; ARM: orr [[R2:r[0-9]+]], [[R2:r[0-9]+]], [[R2:r[0-9]+]], lsl #8
+; ARM: str [[R2:r[0-9]+]], {{\[}}[[R0:r[0-9]+]], [[R1:r[0-9]+]], lsl #2]
+; ARM-LABEL: test_1x4p1_shl:
+; ARM: add [[R2:r[0-9]+]], [[R0:r[0-9]+]], [[R1:r[0-9]+]], lsl #2
+; ARM: ldrb [[R2:r[0-9]+]], {{\[}}[[R2:r[0-9]+]], #1]
+; ARM: lsl [[R3:r[0-9]+]], [[R2:r[0-9]+]], #8
+; ARM: orr [[R2:r[0-9]+]], [[R3:r[0-9]+]], [[R2:r[0-9]+]], lsl #16
+; ARM: str [[R2:r[0-9]+]], {{\[}}[[R0:r[0-9]+]], [[R1:r[0-9]+]], lsl #2]
+
+; armeb
+;foo1:
+;  ldrb r1, [r0]
+;  strb r1, [r0, #2]
+;foo2:
+;  ldrh r1, [r0]
+;  strh r1, [r0, #2]
+;test_1x4p1:
+;  add r2, r0, r1, lsl #2
+;  ldrb r2, [r2, #2]
+;  orr r2, r2, r2, lsl #8
+;  str r2, [r0, r1, lsl #2]
+;test_1x4p1_shl:
+;  add r2, r0, r1, lsl #2
+;  ldrb r2, [r2, #2]
+;  lsl r3, r2, #8
+;  orr r2, r3, r2, lsl #16
+;  str r2, [r0, r1, lsl #2]
+; arm
+;foo1:
+;  ldrb r1, [r0, #1]
+;  strb r1, [r0]
+;foo2:
+;  ldrh r1, [r0, #2]
+;  strh r1, [r0]
+;test_1x4p1:
+;  add r2, r0, r1, lsl #2
+;  ldrb r2, [r2, #1]
+;  orr r2, r2, r2, lsl #8
+;  str r2, [r0, r1, lsl #2]
+;test_1x4p1_shl:
+;  add r2, r0, r1, lsl #2
+;  ldrb r2, [r2, #1]
+;  lsl r3, r2, #8
+;  orr r2, r3, r2, lsl #16
+;  str r2, [r0, r1, lsl #2]
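
For reference, a minimal little-endian IR sketch of the same-address pattern the new checkOverwrittenLoadReduceStoreWidth combine is meant to narrow. This sketch is illustrative and not part of the patch; the function name @copy_byte_to_high is made up, and whether the combine fires on this exact IR depends on how earlier combines shape the DAG.

define void @copy_byte_to_high(i8* %p) {
entry:
  ; Builds st i16 ((zext(ld i8, [M]) to i16) or (shl (zext(ld i8, [M]) to i16), 8)), [M]
  %b = load i8, i8* %p, align 2
  %z = zext i8 %b to i16
  %hi = shl i16 %z, 8
  %v = or i16 %z, %hi
  %p16 = bitcast i8* %p to i16*
  ; The low byte of %v already equals the byte at [M], so on little-endian
  ; targets the store can be narrowed to a single "store i8 %b" at [M+1].
  store i16 %v, i16* %p16, align 2
  ret void
}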