diff --git a/llvm/include/llvm/CodeGen/ByteProvider.h b/llvm/include/llvm/CodeGen/ByteProvider.h new file mode 100644 --- /dev/null +++ b/llvm/include/llvm/CodeGen/ByteProvider.h @@ -0,0 +1,90 @@ +//===-- include/llvm/CodeGen/ByteProvider.h - Map bytes ---------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// \file +// This file implements ByteProvider. The purpose of ByteProvider is to provide +// a map between a target node's byte (byte position is DestOffset) and the +// source (and byte position) that provides it (in Src and SrcOffset, +// respectively). See MatchLoadCombine in CodeGen/SelectionDAG/DAGCombiner.cpp. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CODEGEN_BYTEPROVIDER_H +#define LLVM_CODEGEN_BYTEPROVIDER_H + +#include +#include + +namespace llvm { + +/// Represents the known origin of an individual byte in a combine pattern. The +/// value of the byte is either constant zero or comes from memory or +/// some other productive instruction (e.g. arithmetic instructions). +/// Bit manipulation instructions like shifts are not ByteProviders; rather, +/// they are used to extract bytes. 
+template class ByteProvider { +private: + ByteProvider(std::optional Src, int64_t DestOffset, + int64_t SrcOffset) + : Src(Src), DestOffset(DestOffset), SrcOffset(SrcOffset) {} + + // TODO: Replace this trait with a constraint once C++20 can be used. + // Trait: does the type correspond to an operation in the selection DAG? + template class is_op { + private: + using yes = std::true_type; + using no = std::false_type; + + // Only viable for classes that have a member function getOpcode. + template + static auto test(int) -> decltype(std::declval().getOpcode(), yes()); + + template static no test(...); + + public: + using remove_pointer_t = typename std::remove_pointer::type; + static constexpr bool value = + std::is_same(0)), yes>::value; + }; + +public: + // Src is std::nullopt for constant-zero providers. For actual providers, + // Src represents the node which originally produced the relevant bits. + std::optional Src = std::nullopt; + // DestOffset is the offset of this byte in the destination being mapped. + int64_t DestOffset = 0; + // SrcOffset is the offset in the ultimate source node that maps to + // DestOffset. + int64_t SrcOffset = 0; + + ByteProvider() = default; + + static ByteProvider getSrc(std::optional Val, int64_t ByteOffset, + int64_t VectorOffset) { + static_assert(is_op().value, + "ByteProviders must contain an operation in selection DAG."); + return ByteProvider(Val, ByteOffset, VectorOffset); + } + + static ByteProvider getConstantZero() { + return ByteProvider(std::nullopt, 0, 0); + } + bool isConstantZero() const { return !Src; } + + bool hasSrc() const { return Src.has_value(); } + + bool hasSameSrc(const ByteProvider &Other) const { return Other.Src == Src; } + + bool operator==(const ByteProvider &Other) const { + return Other.Src == Src && Other.DestOffset == DestOffset && + Other.SrcOffset == SrcOffset; + } +}; +} // end namespace llvm + +#endif // LLVM_CODEGEN_BYTEPROVIDER_H diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp 
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -32,6 +32,7 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/CodeGen/ByteProvider.h" #include "llvm/CodeGen/DAGCombine.h" #include "llvm/CodeGen/ISDOpcodes.h" #include "llvm/CodeGen/MachineFunction.h" @@ -8402,42 +8403,6 @@ return SDValue(); } -namespace { - -/// Represents known origin of an individual byte in load combine pattern. The -/// value of the byte is either constant zero or comes from memory. -struct ByteProvider { - // For constant zero providers Load is set to nullptr. For memory providers - // Load represents the node which loads the byte from memory. - // ByteOffset is the offset of the byte in the value produced by the load. - LoadSDNode *Load = nullptr; - unsigned ByteOffset = 0; - unsigned VectorOffset = 0; - - ByteProvider() = default; - - static ByteProvider getMemory(LoadSDNode *Load, unsigned ByteOffset, - unsigned VectorOffset) { - return ByteProvider(Load, ByteOffset, VectorOffset); - } - - static ByteProvider getConstantZero() { return ByteProvider(nullptr, 0, 0); } - - bool isConstantZero() const { return !Load; } - bool isMemory() const { return Load; } - - bool operator==(const ByteProvider &Other) const { - return Other.Load == Load && Other.ByteOffset == ByteOffset && - Other.VectorOffset == VectorOffset; - } - -private: - ByteProvider(LoadSDNode *Load, unsigned ByteOffset, unsigned VectorOffset) - : Load(Load), ByteOffset(ByteOffset), VectorOffset(VectorOffset) {} -}; - -} // end anonymous namespace - /// Recursively traverses the expression calculating the origin of the requested /// byte of the given value. Returns std::nullopt if the provider can't be /// calculated. 
@@ -8479,7 +8444,9 @@ /// LOAD /// /// *ExtractVectorElement -static const std::optional +using SDByteProvider = ByteProvider; + +static const std::optional calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth, std::optional VectorIndex, unsigned StartingIndex = 0) { @@ -8538,7 +8505,7 @@ // provide, then do not provide anything. Otherwise, subtract the index by // the amount we shifted by. return Index < ByteShift - ? ByteProvider::getConstantZero() + ? SDByteProvider::getConstantZero() : calculateByteProvider(Op->getOperand(0), Index - ByteShift, Depth + 1, VectorIndex, Index); } @@ -8553,7 +8520,8 @@ if (Index >= NarrowByteWidth) return Op.getOpcode() == ISD::ZERO_EXTEND - ? std::optional(ByteProvider::getConstantZero()) + ? std::optional( + SDByteProvider::getConstantZero()) : std::nullopt; return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex, StartingIndex); @@ -8603,11 +8571,12 @@ // question if (Index >= NarrowByteWidth) return L->getExtensionType() == ISD::ZEXTLOAD - ? std::optional(ByteProvider::getConstantZero()) + ? 
std::optional( + SDByteProvider::getConstantZero()) : std::nullopt; unsigned BPVectorIndex = VectorIndex.value_or(0U); - return ByteProvider::getMemory(L, Index, BPVectorIndex); + return SDByteProvider::getSrc(L, Index, BPVectorIndex); } } @@ -8901,23 +8870,24 @@ unsigned ByteWidth = VT.getSizeInBits() / 8; bool IsBigEndianTarget = DAG.getDataLayout().isBigEndian(); - auto MemoryByteOffset = [&] (ByteProvider P) { - assert(P.isMemory() && "Must be a memory byte provider"); - unsigned LoadBitWidth = P.Load->getMemoryVT().getScalarSizeInBits(); + auto MemoryByteOffset = [&](SDByteProvider P) { + assert(P.hasSrc() && "Must be a memory byte provider"); + auto *Load = cast(P.Src.value()); + + unsigned LoadBitWidth = Load->getMemoryVT().getScalarSizeInBits(); assert(LoadBitWidth % 8 == 0 && "can only analyze providers for individual bytes not bit"); unsigned LoadByteWidth = LoadBitWidth / 8; - return IsBigEndianTarget - ? bigEndianByteAt(LoadByteWidth, P.ByteOffset) - : littleEndianByteAt(LoadByteWidth, P.ByteOffset); + return IsBigEndianTarget ? 
bigEndianByteAt(LoadByteWidth, P.DestOffset) + : littleEndianByteAt(LoadByteWidth, P.DestOffset); }; std::optional Base; SDValue Chain; SmallPtrSet Loads; - std::optional FirstByteProvider; + std::optional FirstByteProvider; int64_t FirstOffset = INT64_MAX; // Check if all the bytes of the OR we are looking at are loaded from the same @@ -8938,9 +8908,8 @@ return SDValue(); continue; } - assert(P->isMemory() && "provenance should either be memory or zero"); - - LoadSDNode *L = P->Load; + assert(P->hasSrc() && "provenance should either be memory or zero"); + auto *L = cast(P->Src.value()); // All loads must share the same chain SDValue LChain = L->getChain(); @@ -8964,7 +8933,7 @@ unsigned LoadWidthInBit = L->getMemoryVT().getScalarSizeInBits(); if (LoadWidthInBit % 8 != 0) return SDValue(); - unsigned ByteOffsetFromVector = P->VectorOffset * LoadWidthInBit / 8; + unsigned ByteOffsetFromVector = P->SrcOffset * LoadWidthInBit / 8; Ptr.addToOffset(ByteOffsetFromVector); } @@ -9021,7 +8990,7 @@ // So the combined value can be loaded from the first load address. if (MemoryByteOffset(*FirstByteProvider) != 0) return SDValue(); - LoadSDNode *FirstLoad = FirstByteProvider->Load; + auto *FirstLoad = cast(FirstByteProvider->Src.value()); // The node we are looking at matches with the pattern, check if we can // replace it with a single (possibly zero-extended) load and bswap + shift if