diff --git a/llvm/include/llvm/CodeGen/SelectionDAGAddressAnalysis.h b/llvm/include/llvm/CodeGen/SelectionDAGAddressAnalysis.h
--- a/llvm/include/llvm/CodeGen/SelectionDAGAddressAnalysis.h
+++ b/llvm/include/llvm/CodeGen/SelectionDAGAddressAnalysis.h
@@ -49,6 +49,7 @@
   SDValue getBase() const { return Base; }
   SDValue getIndex() { return Index; }
   SDValue getIndex() const { return Index; }
+  void setOffset(int64_t NewOff) { *Offset = NewOff; }
 
   bool hasValidOffset() const { return Offset.has_value(); }
   int64_t getOffset() const { return *Offset; }
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -7815,25 +7815,28 @@
   // ByteOffset is the offset of the byte in the value produced by the load.
   LoadSDNode *Load = nullptr;
   unsigned ByteOffset = 0;
+  unsigned VectorOffset = 0;
 
   ByteProvider() = default;
 
-  static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset) {
-    return ByteProvider(Load, ByteOffset);
+  static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset,
+                                unsigned VectorOffset) {
+    return ByteProvider(Load, ByteOffset, VectorOffset);
   }
 
-  static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0); }
+  static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0, 0); }
 
   bool isConstantZero() const { return !Load; }
   bool isMemory() const { return Load; }
 
   bool operator==(const ByteProvider &Other) const {
-    return Other.Load == Load && Other.ByteOffset == ByteOffset;
+    return Other.Load == Load && Other.ByteOffset == ByteOffset &&
+           Other.VectorOffset == VectorOffset;
   }
 
 private:
-  ByteProvider(LoadSDNode *Load, unsigned ByteOffset)
-      : Load(Load), ByteOffset(ByteOffset) {}
+  ByteProvider(LoadSDNode *Load, unsigned ByteOffset, unsigned VectorOffset)
+      : Load(Load), ByteOffset(ByteOffset), VectorOffset(VectorOffset) {}
 };
 
 } // end anonymous namespace
@@ -7841,25 +7844,55 @@
 /// Recursively traverses the expression calculating the origin of the requested
 /// byte of the given value. Returns None if the provider can't be calculated.
 ///
-/// For all the values except the root of the expression verifies that the value
-/// has exactly one use and if it's not true return None. This way if the origin
-/// of the byte is returned it's guaranteed that the values which contribute to
-/// the byte are not used outside of this expression.
+/// For all the values except the root of the expression, we verify that the
+/// value has exactly one use and return None if not. This way, if the origin
+/// of the byte is returned, it's guaranteed that the values which contribute
+/// to the byte are not used outside of this expression.
+///
+/// However, there is a special case when dealing with vector loads -- we allow
+/// more than one use if the load is a vector type. Since the values that
+/// contribute to the byte ultimately come from the ExtractVectorElements of the
+/// Load, we don't care if the Load has uses other than ExtractVectorElements,
+/// because those operations are independent from the pattern to be combined.
+/// For vector loads, we simply care that the ByteProviders are adjacent
+/// positions of the same vector, and their index matches the byte that is being
+/// provided. This is captured by the \p VectorIndex algorithm.
 ///
-/// Because the parts of the expression are not allowed to have more than one
-/// use this function iterates over trees, not DAGs. So it never visits the same
-/// node more than once.
+/// The supported LoadCombine pattern for vector loads is as follows
+///                              or
+///                          /        \
+///                         or        shl
+///                       /     \      |
+///                     or      shl   zext
+///                   /    \     |     |
+///                 shl   zext  zext  EVE*
+///                  |     |     |     |
+///                 zext  EVE*  EVE*  LOAD
+///                  |     |     |
+///                 EVE*  LOAD  LOAD
+///                  |
+///                 LOAD
+///
+/// *ExtractVectorElement
 static const Optional<ByteProvider>
 calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
-                      bool Root = false) {
+                      Optional<uint64_t> VectorIndex,
+                      unsigned StartingIndex = 0, bool Root = false) {
+  // Typical i64 by i8 pattern requires recursion up to 8 calls depth
   if (Depth == 10)
     return None;
 
-  if (!Root && !Op.hasOneUse())
+  // Multiple uses of vector loads are expected
+  if (!Root && !Op.hasOneUse() &&
+      !(Op.getOpcode() == ISD::LOAD && Op.getValueType().isVector()))
+    return None;
+
+  // Fail to combine if we have encountered anything but a LOAD after handling
+  // an ExtractVectorElement.
+  if (Op.getOpcode() != ISD::LOAD && VectorIndex.has_value())
     return None;
 
-  assert(Op.getValueType().isScalarInteger() && "can't handle other types");
   unsigned BitWidth = Op.getValueSizeInBits();
   if (BitWidth % 8 != 0)
     return None;
@@ -7869,10 +7902,12 @@
 
   switch (Op.getOpcode()) {
   case ISD::OR: {
-    auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1);
+    auto LHS =
+        calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
     if (!LHS)
       return None;
-    auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1);
+    auto RHS =
+        calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
     if (!RHS)
       return None;
 
@@ -7888,6 +7923,7 @@
       return None;
 
     uint64_t BitShift = ShiftOp->getZExtValue();
+
     if (BitShift % 8 != 0)
       return None;
     uint64_t ByteShift = BitShift / 8;
@@ -7895,7 +7931,7 @@
     return Index < ByteShift
               ? ByteProvider::getConstantZero()
              : calculateByteProvider(Op->getOperand(0), Index - ByteShift,
-                                     Depth + 1);
+                                     Depth + 1, VectorIndex, Index);
   }
   case ISD::ANY_EXTEND:
   case ISD::SIGN_EXTEND:
@@ -7910,11 +7946,29 @@
       return Op.getOpcode() == ISD::ZERO_EXTEND
                  ? Optional<ByteProvider>(ByteProvider::getConstantZero())
                  : None;
-    return calculateByteProvider(NarrowOp, Index, Depth + 1);
+    return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
+                                 StartingIndex);
   }
   case ISD::BSWAP:
     return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
-                                 Depth + 1);
+                                 Depth + 1, VectorIndex, StartingIndex);
+  case ISD::EXTRACT_VECTOR_ELT: {
+    auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
+    if (!OffsetOp)
+      return None;
+
+    VectorIndex = OffsetOp->getZExtValue();
+
+    // The byte we are trying to provide (StartingIndex) must correspond with
+    // the vector offset. Otherwise we are shuffling the elements in a vector
+    // and bitpicking them into a scalar, and such a pattern should not be
+    // combined into a load.
+    if (StartingIndex != OffsetOp->getZExtValue())
+      return None;
+
+    return calculateByteProvider(Op->getOperand(0), Index, Depth, VectorIndex,
+                                 StartingIndex);
+  }
   case ISD::LOAD: {
     auto L = cast<LoadSDNode>(Op.getNode());
     if (!L->isSimple() || L->isIndexed())
       return None;
@@ -7929,7 +7983,9 @@
       return L->getExtensionType() == ISD::ZEXTLOAD
                 ? Optional<ByteProvider>(ByteProvider::getConstantZero())
                 : None;
-    return ByteProvider::getMemory(L, Index);
+
+    unsigned BPVectorIndex = VectorIndex.value_or(0U);
+    return ByteProvider::getMemory(L, Index, BPVectorIndex);
   }
   }
 
@@ -8221,7 +8277,15 @@
   bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian();
   auto MemoryByteOffset = [&] (ByteProvider P) {
     assert(P.isMemory() && "Must be a memory byte provider");
-    unsigned LoadBitWidth = P.Load->getMemoryVT().getSizeInBits();
+    unsigned LoadBitWidth;
+    if (P.Load->getValueType(0).isVector()) {
+      LoadBitWidth =
+          P.Load->getValueType(0).getVectorElementType().getSizeInBits();
+    }
+
+    else
+      LoadBitWidth = P.Load->getMemoryVT().getSizeInBits();
+
     assert(LoadBitWidth % 8 == 0 &&
            "can only analyze providers for individual bytes not bit");
     unsigned LoadByteWidth = LoadBitWidth / 8;
@@ -8242,7 +8306,8 @@
   SmallVector<int64_t, 8> ByteOffsets(ByteWidth);
   unsigned ZeroExtendedBytes = 0;
   for (int i = ByteWidth - 1; i >= 0; --i) {
-    auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*Root=*/true);
+    auto P = calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ None, i,
+                                   /*Root=*/true);
     if (!P)
       return SDValue();
 
@@ -8256,10 +8321,6 @@
     assert(P->isMemory() && "provenance should either be memory or zero");
 
     LoadSDNode *L = P->Load;
-    assert(L->hasNUsesOfValue(1, 0) && L->isSimple() &&
-           !L->isIndexed() &&
-           "Must be enforced by calculateByteProvider");
-    assert(L->getOffset().isUndef() && "Unindexed load must have undef offset");
 
     // All loads must share the same chain
     SDValue LChain = L->getChain();
@@ -8271,8 +8332,13 @@
     // Loads must share the same base address
     BaseIndexOffset Ptr = BaseIndexOffset::match(L, DAG);
     int64_t ByteOffsetFromBase = 0;
+
+    // Add the VectorOffset (if any) to the offset of Ptr
+    Ptr.setOffset(Ptr.getOffset() + P->VectorOffset);
+
     if (!Base)
       Base = Ptr;
+
     else if (!Base->equalBaseIndex(Ptr, DAG, ByteOffsetFromBase))
       return SDValue();
diff --git a/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll b/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
--- a/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
+++ b/llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
@@ -9,15 +9,9 @@
 ; GCN-NEXT:    v_mov_b32_e32 v0, s0
 ; GCN-NEXT:    v_mov_b32_e32 v1, s1
 ; GCN-NEXT:    flat_load_dword v2, v[0:1]
-; GCN-NEXT:    s_mov_b32 s0, 0x6050400
 ; GCN-NEXT:    v_mov_b32_e32 v0, s2
 ; GCN-NEXT:    v_mov_b32_e32 v1, s3
 ; GCN-NEXT:    s_waitcnt vmcnt(0) lgkmcnt(0)
-; GCN-NEXT:    v_bfe_u32 v3, v2, 8, 8
-; GCN-NEXT:    v_and_b32_e32 v4, 0xff0000, v2
-; GCN-NEXT:    v_perm_b32 v3, v3, v2, s0
-; GCN-NEXT:    v_and_b32_e32 v2, 0xff000000, v2
-; GCN-NEXT:    v_or3_b32 v2, v3, v4, v2
 ; GCN-NEXT:    flat_store_dword v[0:1], v2
 ; GCN-NEXT:    s_endpgm
 entry:
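
For reference, the IR shape this combine is intended to catch looks roughly like the sketch below. This is an illustrative example only (the function name, the use of opaque ptr, and the value names are made up here and are not taken from the truncated test above): each byte of an i32 is built from adjacent elements of a single <4 x i8> load, matching the or/shl/zext/extractelement/load tree documented in calculateByteProvider.

  define i32 @vload_or_combine_sketch(ptr %in) {
  entry:
    %vec = load <4 x i8>, ptr %in, align 4
    ; extract the four adjacent elements of the vector load
    %e0 = extractelement <4 x i8> %vec, i32 0
    %e1 = extractelement <4 x i8> %vec, i32 1
    %e2 = extractelement <4 x i8> %vec, i32 2
    %e3 = extractelement <4 x i8> %vec, i32 3
    ; widen each byte and shift it into its position in the i32
    %z0 = zext i8 %e0 to i32
    %z1 = zext i8 %e1 to i32
    %z2 = zext i8 %e2 to i32
    %z3 = zext i8 %e3 to i32
    %s1 = shl i32 %z1, 8
    %s2 = shl i32 %z2, 16
    %s3 = shl i32 %z3, 24
    ; or the bytes back together
    %o1 = or i32 %z0, %s1
    %o2 = or i32 %o1, %s2
    %o3 = or i32 %o2, %s3
    ret i32 %o3
  }

With the changes above, calculateByteProvider records the extractelement index of each byte as the ByteProvider's VectorOffset, MemoryByteOffset sizes the provider by the vector element type, and the per-element offsets are folded into the BaseIndexOffset via setOffset, so MatchLoadCombine can replace the whole expression with one scalar i32 load on a little-endian target -- which is what the updated GCN checks reflect (the flat_load_dword result is stored directly).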