diff --git a/llvm/include/llvm/ADT/ByteProvider.h b/llvm/include/llvm/ADT/ByteProvider.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/ADT/ByteProvider.h
@@ -0,0 +1,100 @@
+//===-- llvm/ADT/ByteProvider.h - Map bytes from dest to source -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file
+// This file implements ByteProvider. The purpose of ByteProvider is to provide
+// a map between a target node's byte (byte position is DestOffset) and the
+// source (and byte position) that provides it (in Src and SrcOffset
+// respectively). See CodeGen/SelectionDAG/DAGCombiner.cpp MatchLoadCombine.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_BYTEPROVIDER_H
+#define LLVM_ADT_BYTEPROVIDER_H
+
+#include <optional>
+#include <type_traits>
+#include <utility>
+
+namespace llvm {
+
+/// Represents known origin of an individual byte in combine pattern. The
+/// value of the byte is either constant zero, or comes from memory /
+/// some other productive instruction (e.g. arithmetic instructions).
+/// Bit manipulation instructions like shifts are not ByteProviders, rather
+/// are used to extract Bytes.
+///
+/// \tparam ISelOp Handle to the node that produces the byte. ISelOp (or, if it
+/// is a pointer, the type it points to) must have a getOpcode() member.
+template <typename ISelOp> struct ByteProvider {
+private:
+  // TODO -- use constraint in c++20
+  // Does this type correspond with an operation in selection DAG? is_op<T> is
+  // true iff T, or its pointee when T is a pointer, has a getOpcode() member.
+  // These helpers are class members rather than living in an anonymous
+  // namespace: an internal-linkage entity in a header is a different entity in
+  // every translation unit that includes it.
+  template <typename T>
+  using OpcodeResultT =
+      decltype(std::declval<std::remove_pointer_t<T>>().getOpcode());
+  template <typename T, typename = void> struct is_op : std::false_type {};
+  template <typename T>
+  struct is_op<T, std::void_t<OpcodeResultT<T>>> : std::true_type {};
+
+public:
+  // For constant zero providers Src is set to nullopt. For actual providers
+  // Src represents the node which originally produced the relevant bits.
+  std::optional<ISelOp> Src = std::nullopt;
+  // DestOffset is the offset of the byte in the dest we are trying to map for.
+  unsigned DestOffset = 0;
+  // SrcOffset is the offset in the ultimate source node that maps to the
+  // DestOffset.
+  unsigned SrcOffset = 0;
+
+  ByteProvider() = default;
+
+  /// Returns a provider whose byte comes from \p Val. \p ByteOffset is stored
+  /// in DestOffset and \p VectorOffset in SrcOffset.
+  static ByteProvider getSrc(std::optional<ISelOp> Val, unsigned ByteOffset,
+                             unsigned VectorOffset) {
+    static_assert(
+        is_op<ISelOp>::value,
+        "ByteProviders must correspond with an operation in selection DAG.");
+    return ByteProvider(Val, ByteOffset, VectorOffset);
+  }
+
+  /// Returns a provider for a byte known to be zero; it has no Src.
+  static ByteProvider getConstantZero() {
+    return ByteProvider(std::nullopt, 0, 0);
+  }
+
+  /// True iff this provider stands for a constant zero byte (no Src).
+  bool isConstantZero() const { return !Src; }
+
+  /// True iff the byte comes from a node (the negation of isConstantZero()).
+  bool hasSrc() const { return Src.has_value(); }
+
+  /// True iff both providers have the same Src (or are both constant zero).
+  bool hasSameSrc(const ByteProvider &Other) const { return Other.Src == Src; }
+
+  bool operator==(const ByteProvider &Other) const {
+    return Other.Src == Src && Other.DestOffset == DestOffset &&
+           Other.SrcOffset == SrcOffset;
+  }
+
+  bool operator!=(const ByteProvider &Other) const { return !(*this == Other); }
+
+private:
+  ByteProvider(std::optional<ISelOp> Src, unsigned DestOffset,
+               unsigned SrcOffset)
+      : Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset) {}
+};
+
+} // end namespace llvm
+
+#endif // LLVM_ADT_BYTEPROVIDER_H
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/ByteProvider.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/IntervalMap.h"
 #include "llvm/ADT/None.h"
@@ -7779,42 +7780,6 @@
   return SDValue();
 }
 
-namespace {
-
-/// Represents known origin of an individual byte in load combine pattern.
The -/// value of the byte is either constant zero or comes from memory. -struct ByteProvider { - // For constant zero providers Load is set to nullptr. For memory providers - // Load represents the node which loads the byte from memory. - // ByteOffset is the offset of the byte in the value produced by the load. - LoadSDNode *Load = nullptr; - unsigned ByteOffset = 0; - unsigned VectorOffset = 0; - - ByteProvider() = default; - - static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset, - unsigned VectorOffset) { - return ByteProvider(Load, ByteOffset, VectorOffset); - } - - static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0, 0); } - - bool isConstantZero() const { return !Load; } - bool isMemory() const { return Load; } - - bool operator==(const ByteProvider &Other) const { - return Other.Load == Load && Other.ByteOffset == ByteOffset && - Other.VectorOffset == VectorOffset; - } - -private: - ByteProvider(LoadSDNode *Load, unsigned ByteOffset, unsigned VectorOffset) - : Load(Load), ByteOffset(ByteOffset), VectorOffset(VectorOffset) {} -}; - -} // end anonymous namespace - /// Recursively traverses the expression calculating the origin of the requested /// byte of the given value. Returns None if the provider can't be calculated. /// @@ -7855,11 +7820,10 @@ /// LOAD /// /// *ExtractVectorElement -static const Optional -calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, +static const Optional> // NOTE(review): this hunk also changes Op from SDValue to const SDValue & and removes the comment that explained the Depth == 10 cap; confirm both are intended. +calculateByteProvider(const SDValue &Op, unsigned Index, unsigned Depth, Optional VectorIndex, unsigned StartingIndex = 0) { - // Typical i64 by i8 pattern requires recursion up to 8 calls depth if (Depth == 10) return None; @@ -7914,7 +7878,7 @@ // provide, then do not provide anything. Otherwise, subtract the index by // the amount we shifted by. return Index < ByteShift - ? ByteProvider::getConstantZero() + ?
ByteProvider::getConstantZero() : calculateByteProvider(Op->getOperand(0), Index - ByteShift, Depth + 1, VectorIndex, Index); } @@ -7929,7 +7893,8 @@ if (Index >= NarrowByteWidth) return Op.getOpcode() == ISD::ZERO_EXTEND - ? Optional(ByteProvider::getConstantZero()) + ? Optional>( + ByteProvider::getConstantZero()) : None; return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex, StartingIndex); @@ -7979,11 +7944,12 @@ // question if (Index >= NarrowByteWidth) return L->getExtensionType() == ISD::ZEXTLOAD - ? Optional(ByteProvider::getConstantZero()) + ? Optional>( + ByteProvider::getConstantZero()) : None; unsigned BPVectorIndex = VectorIndex.value_or(0U); - return ByteProvider::getMemory(L, Index, BPVectorIndex); + return ByteProvider::getSrc(L, Index, BPVectorIndex); // Index becomes DestOffset and BPVectorIndex becomes SrcOffset. } } @@ -8273,23 +8239,25 @@ unsigned ByteWidth = VT.getSizeInBits() / 8; bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian(); - auto MemoryByteOffset = [&] (ByteProvider P) { - assert(P.isMemory() && "Must be a memory byte provider"); - unsigned LoadBitWidth = P.Load->getMemoryVT().getScalarSizeInBits(); + auto MemoryByteOffset = [&](ByteProvider P) { + assert(P.hasSrc() && "Must be a memory byte provider"); + assert(isa(P.Src.value())); // NOTE(review): presumably redundant with the checked cast on the next line; if kept, give it a message like its neighbours. + LoadSDNode *Load = cast(P.Src.value()); + + unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits(); assert(LoadBitWidth % 8 == 0 && "can only analyze providers for individual bytes not bit"); unsigned LoadByteWidth = LoadBitWidth / 8; - return IsBigEndianTarget - ? bigEndianByteAt(LoadByteWidth, P.ByteOffset) - : littleEndianByteAt(LoadByteWidth, P.ByteOffset); + return IsBigEndianTarget ?
bigEndianByteAt(LoadByteWidth, P.DestOffset) + : littleEndianByteAt(LoadByteWidth, P.DestOffset); // DestOffset holds the Index that was passed to getSrc for this load. }; Optional Base; SDValue Chain; SmallPtrSet Loads; - Optional FirstByteProvider; + Optional> FirstByteProvider; int64_t FirstOffset = INT64_MAX; // Check if all the bytes of the OR we are looking at are loaded from the same @@ -8309,9 +8277,10 @@ return SDValue(); continue; } - assert(P->isMemory() && "provenance should either be memory or zero"); - LoadSDNode *L = P->Load; + assert(P->hasSrc() && "provenance should either be memory or zero"); + assert(isa(P->Src.value())); + LoadSDNode *L = cast(P->Src.value()); // All loads must share the same chain SDValue LChain = L->getChain(); @@ -8335,7 +8304,7 @@ unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits(); if (LoadWidthInBit % 8 != 0) return SDValue(); - unsigned ByteOffsetFromVector = P->VectorOffset * LoadWidthInBit / 8; + unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8; // SrcOffset holds the BPVectorIndex passed to getSrc. Ptr.addToOffset(ByteOffsetFromVector); } @@ -8392,7 +8361,8 @@ // So the combined value can be loaded from the first load address. if (MemoryByteOffset(*FirstByteProvider) != 0) return SDValue(); - LoadSDNode *FirstLoad = FirstByteProvider->Load; + assert(isa(FirstByteProvider->Src.value())); + LoadSDNode *FirstLoad = cast(FirstByteProvider->Src.value()); // The node we are looking at matches with the pattern, check if we can // replace it with a single (possibly zero-extended) load and bswap + shift if