Index: include/llvm/CodeGen/GlobalISel/MachineLegalizer.h =================================================================== --- /dev/null +++ include/llvm/CodeGen/GlobalISel/MachineLegalizer.h @@ -0,0 +1,150 @@ +//==-- llvm/CodeGen/GlobalISel/MachineLegalizer.h ----------------*- C++ -*-==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +/// Interface for Targets to specify which operations they can successfully +/// select and how the others should be expanded most efficiently. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_CODEGEN_GLOBALISEL_MACHINELEGALIZER_H +#define LLVM_CODEGEN_GLOBALISEL_MACHINELEGALIZER_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/CodeGen/LowLevelType.h" + +#include <cstdint> +#include <functional> + +namespace llvm { +class LLVMContext; +class MachineInstr; +class Type; +class VectorType; + +class MachineLegalizer { +public: + enum LegalizeAction : std::uint8_t { + /// The operation is expected to be selectable directly by the target, and + /// no transformation is necessary. + Legal, + + /// The operation should be synthesized from multiple instructions acting on + /// a narrower scalar base-type. For example a 64-bit add might be + /// implemented in terms of 32-bit add-with-carry. + NarrowScalar, + + /// The operation should be implemented in terms of a wider scalar + /// base-type. For example a <2 x s8> add could be implemented as a <2 + /// x s32> add (ignoring the high bits). + WidenScalar, + + /// The (vector) operation should be implemented by splitting it into + /// sub-vectors where the operation is legal. For example a <8 x s64> add + /// might be implemented as 4 separate <2 x s64> adds.
+ FewerElements, + + /// The (vector) operation should be implemented by widening the input + /// vector and ignoring the lanes added by doing so. For example <2 x i8> is + /// rarely legal, but you might perform an <8 x i8> and then only look at + /// the first two results. + MoreElements, + + /// The operation should be implemented as a call to some kind of runtime + /// support library. For example this usually happens on machines that don't + /// support floating-point operations natively. + Libcall, + + /// The target wants to do something special with this combination of + /// operand and type. A callback will be issued when it is needed. + Custom, + + /// This operation is completely unsupported on the target. A programming + /// error has occurred. + Unsupported, + }; + + MachineLegalizer(); + + /// Replace \p MI by a sequence of legal instructions that can implement the + /// same operation. Note that this means \p MI may be deleted, so any iterator + /// steps should be performed before calling this function. + /// + /// Considered as an opaque blob, the legal code will use and define the same + /// registers as \p MI. + /// + /// \returns true if the function is modified. + bool legalizeInstr(MachineInstr &MI) const; + + /// Compute any ancillary tables needed to quickly decide how an operation + /// should be handled. This must be called after all "set*Action" methods + /// but before any query is made or incorrect results may be returned. + void computeTables(); + + /// More friendly way to set an action for common types that have an LLT + /// representation. + void setAction(unsigned Opcode, LLT Ty, LegalizeAction Action) { + Actions[std::make_pair(Opcode, Ty)] = Action; + } + + /// If an operation on a given vector type (say <M x iN>) isn't explicitly + /// specified, we proceed in 2 stages.
First we legalize the underlying scalar + /// (so that there's at least one legal vector with that scalar), then we + /// adjust the number of elements in the vector so that it is legal. The + /// desired action in the first step is controlled by this function. + void setScalarInVectorAction(unsigned Opcode, LLT ScalarTy, + LegalizeAction Action) { + assert(!ScalarTy.isVector()); + ScalarInVectorActions[std::make_pair(Opcode, ScalarTy)] = Action; + } + + + std::pair<LegalizeAction, LLT> getAction(unsigned Opcode, LLT) const; + std::pair<LegalizeAction, LLT> getAction(MachineInstr &MI) const; + + /// Iterate the given function (typically something like doubling the width) + /// on Ty until we find a legal type for this operation. + LLT findLegalType(unsigned Opcode, LLT Ty, + std::function<LLT(LLT)> NextType) const { + LegalizeAction Action; + do { + Ty = NextType(Ty); + auto ActionIt = Actions.find(std::make_pair(Opcode, Ty)); + if (ActionIt == Actions.end()) { + auto DefaultIt = DefaultActions.find(Opcode); + assert(DefaultIt != DefaultActions.end() && "no default action"); + Action = DefaultIt->second; + } else + Action = ActionIt->second; + } while(Action != Legal); + return Ty; + } + + /// Find what type it's actually OK to perform the given operation on, given + /// the general approach we've decided to take. + LLT findLegalType(unsigned Opcode, LLT Ty, LegalizeAction Action) const; + + std::pair<LegalizeAction, LLT> findLegalAction(unsigned Opcode, LLT Ty, + LegalizeAction Action) const { + return std::make_pair(Action, findLegalType(Opcode, Ty, Action)); + } + + bool isLegal(MachineInstr &MI) const; + +private: + typedef DenseMap<std::pair<unsigned, LLT>, LegalizeAction> ActionMap; + + ActionMap Actions; + ActionMap ScalarInVectorActions; + DenseMap<std::pair<unsigned, LLT>, uint16_t> MaxLegalVectorElts; + DenseMap<unsigned, LegalizeAction> DefaultActions; +}; + +} // End namespace llvm. + +#endif Index: include/llvm/CodeGen/LowLevelType.h =================================================================== --- include/llvm/CodeGen/LowLevelType.h +++ include/llvm/CodeGen/LowLevelType.h @@ -47,23 +47,31 @@ }; /// \brief get a low-level scalar or aggregate "bag of bits".
- static LLT scalar(int SizeInBits) { + static LLT scalar(unsigned SizeInBits) { return LLT{Scalar, 1, SizeInBits}; } /// \brief get a low-level vector of some number of elements and element - /// width. - static LLT vector(int NumElements, int ScalarSizeInBits) { + /// width. \p NumElements must be at least 2. + static LLT vector(uint16_t NumElements, unsigned ScalarSizeInBits) { assert(NumElements > 1 && "invalid number of vector elements"); return LLT{Vector, NumElements, ScalarSizeInBits}; } - /// \brif get an unsized but valid low-level type (e.g. for a label). + /// \brief get a low-level vector of some number of elements and element + /// type + static LLT vector(uint16_t NumElements, LLT ScalarTy) { + assert(NumElements > 1 && "invalid number of vector elements"); + assert(ScalarTy.isScalar() && "invalid vector element type"); + return LLT{Vector, NumElements, ScalarTy.getSizeInBits()}; + } + + /// \brief get an unsized but valid low-level type (e.g. for a label). static LLT unsized() { return LLT{Unsized, 1, 0}; } - explicit LLT(TypeKind Kind, int NumElements, int ScalarSizeInBits) + explicit LLT(TypeKind Kind, uint16_t NumElements, unsigned ScalarSizeInBits) : ScalarSize(ScalarSizeInBits), NumElements(NumElements), Kind(Kind) { assert((Kind != Vector || NumElements > 1) && "invalid number of vector elements"); @@ -72,29 +80,36 @@ explicit LLT() : ScalarSize(0), NumElements(0), Kind(Invalid) {} /// \brief construct a low-level type based on an LLVM type. - explicit LLT(Type *Ty); + explicit LLT(const Type &Ty); bool isValid() const { return Kind != Invalid; } + bool isScalar() const { return Kind == Scalar; } + bool isVector() const { return Kind == Vector; } bool isSized() const { return Kind == Scalar || Kind == Vector; } - int getNumElements() const { + /// \brief Returns the number of elements in a vector LLT. Must only be called + /// on vector types. 
+ uint16_t getNumElements() const { assert(isVector() && "cannot get number of elements on scalar/aggregate"); return NumElements; } - int getSizeInBits() const { + /// \brief Returns the total size of the type. Must only be called on sized + /// types. + unsigned getSizeInBits() const { assert(isSized() && "attempt to get size of unsized type"); return ScalarSize * NumElements; } - int getScalarSizeInBits() const { + unsigned getScalarSizeInBits() const { assert(isSized() && "cannot get size of this type"); return ScalarSize; } + /// \brief Returns the vector's element type. Only valid for vector types. LLT getElementType() const { assert(isVector() && "cannot get element type of scalar/aggregate"); return scalar(ScalarSize); @@ -125,7 +140,7 @@ if (NumElements == 2) return scalar(ScalarSize); - return LLT{Vector, NumElements / 2, ScalarSize}; + return LLT{Vector, static_cast<uint16_t>(NumElements / 2), ScalarSize}; } /// \brief get a low-level type with twice the size of the original, by @@ -133,7 +148,7 @@ /// source must be a vector type. For example `<2 x s32>` will become `<4 x /// s32>`. Doubling the number of elements in sN produces <2 x sN>.
LLT doubleElements() const { - return LLT{Vector, NumElements * 2, ScalarSize}; + return LLT{Vector, static_cast<uint16_t>(NumElements * 2), ScalarSize}; } void print(raw_ostream &OS) const; @@ -145,17 +160,17 @@ friend struct DenseMapInfo<LLT>; private: - int ScalarSize; - int16_t NumElements; + unsigned ScalarSize; + uint16_t NumElements; TypeKind Kind; }; template<> struct DenseMapInfo<LLT> { static inline LLT getEmptyKey() { - return LLT{LLT::Invalid, 0, -1}; + return LLT{LLT::Invalid, 0, -1u}; } static inline LLT getTombstoneKey() { - return LLT{LLT::Invalid, 0, -2}; + return LLT{LLT::Invalid, 0, -2u}; } static inline unsigned getHashValue(const LLT &Ty) { uint64_t Val = ((uint64_t)Ty.ScalarSize << 32) | Index: lib/CodeGen/GlobalISel/CMakeLists.txt =================================================================== --- lib/CodeGen/GlobalISel/CMakeLists.txt +++ lib/CodeGen/GlobalISel/CMakeLists.txt @@ -2,6 +2,7 @@ set(GLOBAL_ISEL_FILES IRTranslator.cpp MachineIRBuilder.cpp + MachineLegalizer.cpp RegBankSelect.cpp RegisterBank.cpp RegisterBankInfo.cpp Index: lib/CodeGen/GlobalISel/IRTranslator.cpp =================================================================== --- lib/CodeGen/GlobalISel/IRTranslator.cpp +++ lib/CodeGen/GlobalISel/IRTranslator.cpp @@ -69,7 +69,7 @@ unsigned Op0 = getOrCreateVReg(*Inst.getOperand(0)); unsigned Op1 = getOrCreateVReg(*Inst.getOperand(1)); unsigned Res = getOrCreateVReg(Inst); - MIRBuilder.buildInstr(Opcode, LLT{Inst.getType()}, Res, Op0, Op1); + MIRBuilder.buildInstr(Opcode, LLT{*Inst.getType()}, Res, Op0, Op1); return true; } @@ -88,7 +88,7 @@ if (BrInst.isUnconditional()) { const BasicBlock &BrTgt = *cast<BasicBlock>(BrInst.getOperand(0)); MachineBasicBlock &TgtBB = getOrCreateBB(BrTgt); - MIRBuilder.buildInstr(TargetOpcode::G_BR, LLT{BrTgt.getType()}, TgtBB); + MIRBuilder.buildInstr(TargetOpcode::G_BR, LLT{*BrTgt.getType()}, TgtBB); } else { assert(0 && "Not yet implemented"); } Index: lib/CodeGen/GlobalISel/MachineLegalizer.cpp
=================================================================== --- /dev/null +++ lib/CodeGen/GlobalISel/MachineLegalizer.cpp @@ -0,0 +1,126 @@ +//===---- lib/CodeGen/GlobalISel/MachineLegalizer.cpp - Legalizer ----------==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Implement an interface to specify and query how an illegal operation on a +// given type should be expanded. +// +// Issues to be resolved: +// + Make it fast. +// + Support weird types like i3, <7 x i3>, ... +// + Operations with more than one type (ICMP, CMPXCHG, intrinsics, ...) +// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/MachineInstr.h" +#include "llvm/CodeGen/ValueTypes.h" +#include "llvm/CodeGen/GlobalISel/MachineLegalizer.h" +#include "llvm/IR/Type.h" +#include "llvm/Target/TargetOpcodes.h" +using namespace llvm; + +MachineLegalizer::MachineLegalizer() { + DefaultActions[TargetOpcode::G_ADD] = NarrowScalar; +} + +bool MachineLegalizer::legalizeInstr(MachineInstr &MI) const { + llvm_unreachable("Unimplemented functionality"); +} + +void MachineLegalizer::computeTables() { + for (auto &Op : Actions) { + LLT Ty = Op.first.second; + if (!Ty.isVector()) + continue; + + auto &Entry = + MaxLegalVectorElts[std::make_pair(Op.first.first, Ty.getElementType())]; + Entry = std::max(Entry, Ty.getNumElements()); + } +} + +// FIXME: inefficient implementation for now. Without ComputeValueVTs we're +// probably going to need specialized lookup structures for various types before +// we have any hope of doing well with something like <13 x i3>. Even the common +// cases should do better than what we have now.
+std::pair<MachineLegalizer::LegalizeAction, LLT> +MachineLegalizer::getAction(unsigned Opcode, LLT Ty) const { + // These *have* to be implemented for now, they're the fundamental basis of + // how everything else is transformed. + + auto ActionIt = Actions.find(std::make_pair(Opcode, Ty)); + if (ActionIt != Actions.end()) + return findLegalAction(Opcode, Ty, ActionIt->second); + + if (!Ty.isVector()) { + auto DefaultAction = DefaultActions.find(Opcode); + if (DefaultAction != DefaultActions.end() && DefaultAction->second == Legal) + return std::make_pair(Legal, Ty); + + assert(DefaultAction != DefaultActions.end() && + DefaultAction->second == NarrowScalar && "unexpected default"); + return findLegalAction(Opcode, Ty, NarrowScalar); + } + + LLT EltTy = Ty.getElementType(); + int NumElts = Ty.getNumElements(); + + auto ScalarAction = ScalarInVectorActions.find(std::make_pair(Opcode, EltTy)); + if (ScalarAction != ScalarInVectorActions.end() && + ScalarAction->second != Legal) + return findLegalAction(Opcode, EltTy, ScalarAction->second); + + // The element type is legal in principle, but the number of elements is + // wrong. + auto MaxLegalElts = MaxLegalVectorElts.lookup(std::make_pair(Opcode, EltTy)); + if (MaxLegalElts > NumElts) + return findLegalAction(Opcode, Ty, MoreElements); + + if (MaxLegalElts == 0) { + // Scalarize if there's no legal vector type, which is just a special case + // of FewerElements.
+ return std::make_pair(FewerElements, EltTy); + } + + return findLegalAction(Opcode, Ty, FewerElements); +} + +std::pair<MachineLegalizer::LegalizeAction, LLT> +MachineLegalizer::getAction(MachineInstr &MI) const { + return getAction(MI.getOpcode(), MI.getType()); +} + +bool MachineLegalizer::isLegal(MachineInstr &MI) const { + return getAction(MI).first == Legal; +} + +LLT MachineLegalizer::findLegalType(unsigned Opcode, LLT Ty, + LegalizeAction Action) const { + switch(Action) { + default: + llvm_unreachable("Cannot find legal type"); + case Legal: + return Ty; + case NarrowScalar: { + return findLegalType(Opcode, Ty, + [&](LLT Ty) -> LLT { return Ty.halfScalarSize(); }); + } + case WidenScalar: { + return findLegalType(Opcode, Ty, + [&](LLT Ty) -> LLT { return Ty.doubleScalarSize(); }); + } + case FewerElements: { + return findLegalType(Opcode, Ty, + [&](LLT Ty) -> LLT { return Ty.halfElements(); }); + } + case MoreElements: { + return findLegalType( + Opcode, Ty, [&](LLT Ty) -> LLT { return Ty.doubleElements(); }); + } + } +} Index: lib/CodeGen/LowLevelType.cpp =================================================================== --- lib/CodeGen/LowLevelType.cpp +++ lib/CodeGen/LowLevelType.cpp @@ -17,16 +17,16 @@ #include "llvm/Support/raw_ostream.h" using namespace llvm; -LLT::LLT(Type *Ty) { - if (VectorType *VTy = dyn_cast<VectorType>(Ty)) { +LLT::LLT(const Type &Ty) { + if (auto VTy = dyn_cast<VectorType>(&Ty)) { ScalarSize = VTy->getElementType()->getPrimitiveSizeInBits(); NumElements = VTy->getNumElements(); Kind = NumElements == 1 ? Scalar : Vector; - } else if (Ty->isSized()) { + } else if (Ty.isSized()) { // Aggregates are no different from real scalars as far as GlobalISel is // concerned.
Kind = Scalar; - ScalarSize = Ty->getPrimitiveSizeInBits(); + ScalarSize = Ty.getPrimitiveSizeInBits(); NumElements = 1; } else { Kind = Unsized; Index: unittests/CodeGen/CMakeLists.txt =================================================================== --- unittests/CodeGen/CMakeLists.txt +++ unittests/CodeGen/CMakeLists.txt @@ -10,3 +10,5 @@ add_llvm_unittest(CodeGenTests ${CodeGenSources} ) + +add_subdirectory(GlobalISel) Index: unittests/CodeGen/GlobalISel/CMakeLists.txt =================================================================== --- /dev/null +++ unittests/CodeGen/GlobalISel/CMakeLists.txt @@ -0,0 +1,9 @@ +set(LLVM_LINK_COMPONENTS + GlobalISel + ) + +if(LLVM_BUILD_GLOBAL_ISEL) + add_llvm_unittest(GlobalISelTests + MachineLegalizerTest.cpp + ) +endif() Index: unittests/CodeGen/GlobalISel/MachineLegalizerTest.cpp =================================================================== --- /dev/null +++ unittests/CodeGen/GlobalISel/MachineLegalizerTest.cpp @@ -0,0 +1,102 @@ +//===- llvm/unittest/CodeGen/GlobalISel/MachineLegalizerTest.cpp ----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/CodeGen/GlobalISel/MachineLegalizer.h" +#include "llvm/Target/TargetOpcodes.h" +#include "gtest/gtest.h" + +using namespace llvm; +using llvm::MachineLegalizer::LegalizeAction::Legal; +using llvm::MachineLegalizer::LegalizeAction::NarrowScalar; +using llvm::MachineLegalizer::LegalizeAction::WidenScalar; +using llvm::MachineLegalizer::LegalizeAction::FewerElements; +using llvm::MachineLegalizer::LegalizeAction::MoreElements; +using llvm::MachineLegalizer::LegalizeAction::Libcall; +using llvm::MachineLegalizer::LegalizeAction::Custom; +using llvm::MachineLegalizer::LegalizeAction::Unsupported; + +// Define a couple of pretty printers to help debugging when things go wrong. +namespace llvm { +std::ostream & +operator<<(std::ostream &OS, const llvm::MachineLegalizer::LegalizeAction Act) { + switch (Act) { + case Legal: OS << "Legal"; break; + case NarrowScalar: OS << "NarrowScalar"; break; + case WidenScalar: OS << "WidenScalar"; break; + case FewerElements: OS << "FewerElements"; break; + case MoreElements: OS << "MoreElements"; break; + case Libcall: OS << "Libcall"; break; + case Custom: OS << "Custom"; break; + case Unsupported: OS << "Unsupported"; break; + } + return OS; +} + +std::ostream & +operator<<(std::ostream &OS, const llvm::LLT Ty) { + std::string Repr; + raw_string_ostream SS{Repr}; + Ty.print(SS); + OS << SS.str(); + return OS; +} +} + +namespace { + + +TEST(MachineLegalizerTest, ScalarRISC) { + MachineLegalizer L; + // Typical RISCy set of operations based on AArch64. + L.setAction(TargetOpcode::G_ADD, LLT::scalar(8), WidenScalar); + L.setAction(TargetOpcode::G_ADD, LLT::scalar(16), WidenScalar); + L.setAction(TargetOpcode::G_ADD, LLT::scalar(32), Legal); + L.setAction(TargetOpcode::G_ADD, LLT::scalar(64), Legal); + L.computeTables(); + + // Check we infer the correct types and actually do what we're told. 
+ ASSERT_EQ(L.getAction(TargetOpcode::G_ADD, LLT::scalar(8)), + std::make_pair(WidenScalar, LLT::scalar(32))); + ASSERT_EQ(L.getAction(TargetOpcode::G_ADD, LLT::scalar(16)), + std::make_pair(WidenScalar, LLT::scalar(32))); + ASSERT_EQ(L.getAction(TargetOpcode::G_ADD, LLT::scalar(32)), + std::make_pair(Legal, LLT::scalar(32))); + ASSERT_EQ(L.getAction(TargetOpcode::G_ADD, LLT::scalar(64)), + std::make_pair(Legal, LLT::scalar(64))); + + // Make sure the default for over-sized types applies. + ASSERT_EQ(L.getAction(TargetOpcode::G_ADD, LLT::scalar(128)), + std::make_pair(NarrowScalar, LLT::scalar(64))); +} + +TEST(MachineLegalizerTest, VectorRISC) { + MachineLegalizer L; + // Typical RISCy set of operations based on ARM. + L.setScalarInVectorAction(TargetOpcode::G_ADD, LLT::scalar(8), Legal); + L.setScalarInVectorAction(TargetOpcode::G_ADD, LLT::scalar(16), Legal); + L.setScalarInVectorAction(TargetOpcode::G_ADD, LLT::scalar(32), Legal); + + L.setAction(TargetOpcode::G_ADD, LLT::vector(8, 8), Legal); + L.setAction(TargetOpcode::G_ADD, LLT::vector(16, 8), Legal); + L.setAction(TargetOpcode::G_ADD, LLT::vector(4, 16), Legal); + L.setAction(TargetOpcode::G_ADD, LLT::vector(8, 16), Legal); + L.setAction(TargetOpcode::G_ADD, LLT::vector(2, 32), Legal); + L.setAction(TargetOpcode::G_ADD, LLT::vector(4, 32), Legal); + L.computeTables(); + + // Check we infer the correct types and actually do what we're told for some + // simple cases. + ASSERT_EQ(L.getAction(TargetOpcode::G_ADD, LLT::vector(2, 8)), + std::make_pair(MoreElements, LLT::vector(8, 8))); + ASSERT_EQ(L.getAction(TargetOpcode::G_ADD, LLT::vector(8, 8)), + std::make_pair(Legal, LLT::vector(8, 8))); + ASSERT_EQ(L.getAction(TargetOpcode::G_ADD, LLT::vector(8, 32)), + std::make_pair(FewerElements, LLT::vector(4, 32))); +} +}