diff --git a/llvm/include/llvm/IR/LLVMContext.h b/llvm/include/llvm/IR/LLVMContext.h
--- a/llvm/include/llvm/IR/LLVMContext.h
+++ b/llvm/include/llvm/IR/LLVMContext.h
@@ -24,6 +24,7 @@
 
 namespace llvm {
 
+class Any;
 class DiagnosticInfo;
 enum DiagnosticSeverity : char;
 class Function;
@@ -315,6 +316,10 @@
   /// Whether typed pointers are supported. If false, all pointers are opaque.
   bool supportsTypedPointers() const;
 
+  /// Optional target-specific data can be attached to the context for lifetime
+  /// management and to bypass layering restrictions.
+  llvm::Any &getTargetData() const;
+
 private:
   // Module needs access to the add/removeModule methods.
   friend class Module;
diff --git a/llvm/include/llvm/IR/Type.h b/llvm/include/llvm/IR/Type.h
--- a/llvm/include/llvm/IR/Type.h
+++ b/llvm/include/llvm/IR/Type.h
@@ -68,13 +68,14 @@
     TokenTyID,     ///< Tokens
 
     // Derived types... see DerivedTypes.h file.
-    IntegerTyID,       ///< Arbitrary bit width integers
-    FunctionTyID,      ///< Functions
-    PointerTyID,       ///< Pointers
-    StructTyID,        ///< Structures
-    ArrayTyID,         ///< Arrays
-    FixedVectorTyID,   ///< Fixed width SIMD vector type
-    ScalableVectorTyID ///< Scalable SIMD vector type
+    IntegerTyID,        ///< Arbitrary bit width integers
+    FunctionTyID,       ///< Functions
+    PointerTyID,        ///< Pointers
+    StructTyID,         ///< Structures
+    ArrayTyID,          ///< Arrays
+    FixedVectorTyID,    ///< Fixed width SIMD vector type
+    ScalableVectorTyID, ///< Scalable SIMD vector type
+    DXILPointerTyID,    ///< DXIL typed pointer used by DirectX target
   };
 
 private:
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -714,6 +714,8 @@
         TypeVals.push_back(true);
       break;
     }
+    case Type::DXILPointerTyID:
+      llvm_unreachable("DXIL pointers cannot be added to IR modules");
     }
 
     // Emit the finished record.
diff --git a/llvm/lib/IR/AsmWriter.cpp b/llvm/lib/IR/AsmWriter.cpp
--- a/llvm/lib/IR/AsmWriter.cpp
+++ b/llvm/lib/IR/AsmWriter.cpp
@@ -612,6 +612,13 @@
     OS << '>';
     return;
   }
+  case Type::DXILPointerTyID:
+    // dxil::TypedPointerType keeps its element type as its only contained
+    // type. The address space is not printed.
+    OS << "dxil-ptr (";
+    print(Ty->getContainedType(0), OS);
+    OS << ")";
+    return;
   }
   llvm_unreachable("Invalid TypeID");
 }
diff --git a/llvm/lib/IR/Core.cpp b/llvm/lib/IR/Core.cpp
--- a/llvm/lib/IR/Core.cpp
+++ b/llvm/lib/IR/Core.cpp
@@ -534,6 +534,8 @@
     return LLVMTokenTypeKind;
   case Type::ScalableVectorTyID:
     return LLVMScalableVectorTypeKind;
+  case Type::DXILPointerTyID:
+    llvm_unreachable("DXIL pointers are unsupported via the C API");
   }
   llvm_unreachable("Unhandled TypeID.");
 }
diff --git a/llvm/lib/IR/LLVMContext.cpp b/llvm/lib/IR/LLVMContext.cpp
--- a/llvm/lib/IR/LLVMContext.cpp
+++ b/llvm/lib/IR/LLVMContext.cpp
@@ -364,3 +364,7 @@
 bool LLVMContext::supportsTypedPointers() const {
   return !pImpl->getOpaquePointers();
 }
+
+Any &LLVMContext::getTargetData() const {
+  return pImpl->TargetDataStorage;
+}
diff --git a/llvm/lib/IR/LLVMContextImpl.h b/llvm/lib/IR/LLVMContextImpl.h
--- a/llvm/lib/IR/LLVMContextImpl.h
+++ b/llvm/lib/IR/LLVMContextImpl.h
@@ -17,6 +17,7 @@
 #include "ConstantsContext.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/APInt.h"
+#include "llvm/ADT/Any.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/DenseMapInfo.h"
@@ -1558,6 +1559,8 @@
   bool hasOpaquePointersValue();
   void setOpaquePointers(bool OP);
 
+  llvm::Any TargetDataStorage;
+
 private:
   Optional<bool> OpaquePointers;
 };
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -10,7 +10,9 @@
   DirectXSubtarget.cpp
   DirectXTargetMachine.cpp
   DXILBitcodeWriter.cpp
+  DXILPointerType.cpp
   DXILPrepare.cpp
+  PointerTypeAnalysis.cpp
 
   LINK_COMPONENTS
   Bitwriter
diff --git a/llvm/lib/Target/DirectX/DXILPointerType.h b/llvm/lib/Target/DirectX/DXILPointerType.h
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILPointerType.h
@@ -0,0 +1,55 @@
+//===- Target/DirectX/DXILPointerType.h - DXIL Typed Pointer Type ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Declares dxil::TypedPointerType, a pointer type that carries its element
+// type, used to model DXIL's typed pointers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TARGET_DIRECTX_DXILPOINTERTYPE_H
+#define LLVM_TARGET_DIRECTX_DXILPOINTERTYPE_H
+
+#include "llvm/IR/Type.h"
+
+namespace llvm {
+namespace dxil {
+
+// DXIL has typed pointers; this pointer type abstraction is used for tracking
+// in PointerTypeAnalysis and for the bitcode ValueEnumerator.
+class TypedPointerType : public Type {
+  explicit TypedPointerType(Type *ElType, unsigned AddrSpace);
+
+  Type *PointeeTy;
+
+public:
+  TypedPointerType(const TypedPointerType &) = delete;
+  TypedPointerType &operator=(const TypedPointerType &) = delete;
+
+  /// This constructs a pointer to an object of the specified type in a numbered
+  /// address space.
+  static TypedPointerType *get(Type *ElementType, unsigned AddressSpace);
+
+  /// Return true if the specified type is valid as an element type.
+  static bool isValidElementType(Type *ElemTy);
+
+  /// Return the address space of the Pointer type.
+  inline unsigned getAddressSpace() const { return getSubclassData(); }
+
+  /// Return the type this pointer points to.
+  Type *getElementType() const { return PointeeTy; }
+
+  /// Implement support type inquiry through isa, cast, and dyn_cast.
+  static bool classof(const Type *T) {
+    return T->getTypeID() == DXILPointerTyID;
+  }
+};
+
+} // namespace dxil
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_DXILPOINTERTYPE_H
diff --git a/llvm/lib/Target/DirectX/DXILPointerType.cpp b/llvm/lib/Target/DirectX/DXILPointerType.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILPointerType.cpp
@@ -0,0 +1,72 @@
+//===- Target/DirectX/DXILPointerType.cpp - DXIL Typed Pointer Type -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Implements dxil::TypedPointerType. Instances are uniqued per LLVMContext and
+// owned by tracking storage kept in LLVMContext::getTargetData().
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILPointerType.h"
+#include "llvm/ADT/Any.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/LLVMContext.h"
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+namespace {
+// Owns every TypedPointerType created for one LLVMContext. It is held through
+// a std::shared_ptr because llvm::Any requires copy-constructible contents.
+struct TypedPointerTracking {
+  DenseMap<Type *, std::unique_ptr<TypedPointerType>> PointerTypes;
+  DenseMap<std::pair<Type *, unsigned>, std::unique_ptr<TypedPointerType>>
+      ASPointerTypes;
+};
+} // anonymous namespace
+
+TypedPointerType *TypedPointerType::get(Type *EltTy, unsigned AddressSpace) {
+  assert(EltTy && "Can't get a pointer to <null> type!");
+  assert(isValidElementType(EltTy) && "Invalid type for pointer element!");
+
+  // Lazily create the per-context tracking storage on first use.
+  llvm::Any &TargetData = EltTy->getContext().getTargetData();
+  if (!TargetData.hasValue())
+    TargetData = Any{std::make_shared<TypedPointerTracking>()};
+
+  assert(any_isa<std::shared_ptr<TypedPointerTracking>>(TargetData) &&
+         "Unexpected target data type");
+
+  std::shared_ptr<TypedPointerTracking> Tracking =
+      any_cast<std::shared_ptr<TypedPointerTracking>>(TargetData);
+
+  // Since AddressSpace #0 is the common case, we special case it.
+  std::unique_ptr<TypedPointerType> &Entry =
+      AddressSpace == 0
+          ? Tracking->PointerTypes[EltTy]
+          : Tracking->ASPointerTypes[std::make_pair(EltTy, AddressSpace)];
+
+  // The constructor is private, so std::make_unique cannot be used here.
+  if (!Entry)
+    Entry = std::unique_ptr<TypedPointerType>(
+        new TypedPointerType(EltTy, AddressSpace));
+  return Entry.get();
+}
+
+TypedPointerType::TypedPointerType(Type *E, unsigned AddrSpace)
+    : Type(E->getContext(), DXILPointerTyID), PointeeTy(E) {
+  // Expose the element type through the generic contained-type interface.
+  ContainedTys = &PointeeTy;
+  NumContainedTys = 1;
+  setSubclassData(AddrSpace);
+}
+
+bool TypedPointerType::isValidElementType(Type *ElemTy) {
+  return !ElemTy->isVoidTy() && !ElemTy->isLabelTy() &&
+         !ElemTy->isMetadataTy() && !ElemTy->isTokenTy() &&
+         !ElemTy->isX86_AMXTy();
+}
diff --git a/llvm/lib/Target/DirectX/PointerTypeAnalysis.h b/llvm/lib/Target/DirectX/PointerTypeAnalysis.h
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/PointerTypeAnalysis.h
@@ -0,0 +1,41 @@
+//===- Target/DirectX/PointerTypeAnalysis.h - PointerType analysis --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Analysis pass to assign types to opaque pointers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TARGET_DIRECTX_POINTERTYPEANALYSIS_H
+#define LLVM_TARGET_DIRECTX_POINTERTYPEANALYSIS_H
+
+#include "DXILPointerType.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+// Maps each analyzed Value to its inferred type: opaque pointers map to a
+// dxil::TypedPointerType, and functions with opaque pointers in their
+// signature map to a FunctionType that uses typed pointers instead.
+using PointerTypeMap = DenseMap<const Value *, Type *>;
+
+/// An analysis to compute the \c PointerTypes for pointers in a \c Module.
+/// Since this analysis is only run during codegen and the new pass manager
+/// doesn't support codegen passes, this is written as a function in a
+/// namespace. It is very simple to transform it into a proper analysis pass.
+/// Pointee types are modeled with \c dxil::TypedPointerType because the
+/// pointers in the analyzed IR are opaque.
+namespace PointerTypeAnalysis {
+
+/// Compute the \c PointerTypeMap for the module \c M.
+PointerTypeMap run(const Module &M);
+} // namespace PointerTypeAnalysis
+
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_POINTERTYPEANALYSIS_H
diff --git a/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp b/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp
@@ -0,0 +1,130 @@
+//===- Target/DirectX/PointerTypeAnalysis.cpp - PointerType analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Analysis pass to assign types to opaque pointers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PointerTypeAnalysis.h"
+#include "llvm/IR/Instructions.h"
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+namespace {
+// Infers the pointee type of the opaque pointer V from how V is produced (GEP,
+// alloca) and how it is used (load, store, GEP). Conflicting evidence, or no
+// evidence at all, yields i8.
+TypedPointerType *classifyPointerType(const Value *V) {
+  assert(V->getType()->isOpaquePointerTy() &&
+         "classifyPointerType called with non-opaque pointer");
+  Type *PointeeTy = nullptr;
+  if (auto *Inst = dyn_cast<GetElementPtrInst>(V)) {
+    PointeeTy = Inst->getResultElementType()->isOpaquePointerTy()
+                    ? nullptr
+                    : Inst->getResultElementType();
+  } else if (auto *Inst = dyn_cast<AllocaInst>(V)) {
+    PointeeTy = Inst->getAllocatedType();
+  }
+  for (const User *U : V->users()) {
+    Type *NewPointeeTy = nullptr;
+    if (auto *Inst = dyn_cast<LoadInst>(U)) {
+      NewPointeeTy = Inst->getType();
+    } else if (auto *Inst = dyn_cast<StoreInst>(U)) {
+      // Storing a pointer reveals nothing usable here: the store has no
+      // pointer result to classify, and V may be the stored value itself.
+      if (Inst->getValueOperand()->getType()->isOpaquePointerTy())
+        continue;
+      NewPointeeTy = Inst->getValueOperand()->getType();
+    } else if (auto *Inst = dyn_cast<GetElementPtrInst>(U)) {
+      // A GEP over pointer elements takes a single index, so its result has
+      // the same type as V rather than one more level of indirection.
+      if (Inst->getSourceElementType()->isOpaquePointerTy())
+        return classifyPointerType(Inst);
+      NewPointeeTy = Inst->getSourceElementType();
+    }
+    if (NewPointeeTy) {
+      // V points to a pointer (`load ptr, ptr V`): classify the loaded pointer
+      // and add one level of indirection.
+      if (NewPointeeTy->isOpaquePointerTy())
+        return TypedPointerType::get(classifyPointerType(U),
+                                     V->getType()->getPointerAddressSpace());
+      if (!PointeeTy)
+        PointeeTy = NewPointeeTy;
+      else if (PointeeTy != NewPointeeTy)
+        PointeeTy = Type::getInt8Ty(V->getContext());
+    }
+  }
+  // If we were unable to determine the pointee type, set to i8
+  if (!PointeeTy)
+    PointeeTy = Type::getInt8Ty(V->getContext());
+  return TypedPointerType::get(PointeeTy,
+                               V->getType()->getPointerAddressSpace());
+}
+
+// Records typed signatures for F: each opaque-pointer argument is mapped to a
+// TypedPointerType, and F is mapped to a FunctionType when its return type or
+// any parameter type is an opaque pointer.
+void handleFunction(const Function &F, PointerTypeMap &Map) {
+  SmallVector<Type *, 8> NewArgs;
+  Type *RetTy = F.getReturnType();
+  bool HasOpaqueTy = RetTy->isOpaquePointerTy();
+  if (HasOpaqueTy) {
+    TypedPointerType *I8PtrTy = TypedPointerType::get(
+        Type::getInt8Ty(F.getContext()), RetTy->getPointerAddressSpace());
+    RetTy = nullptr;
+    for (const auto &B : F) {
+      // Only a block's terminator can be a return; it is null if missing.
+      const auto *RetInst = dyn_cast_or_null<ReturnInst>(B.getTerminator());
+      if (!RetInst)
+        continue;
+      Type *NewRetTy = classifyPointerType(RetInst->getReturnValue());
+      if (!RetTy)
+        RetTy = NewRetTy;
+      else if (RetTy != NewRetTy)
+        RetTy = I8PtrTy;
+    }
+    // Declarations and functions that never return have no returned value to
+    // inspect; fall back to i8* just as classifyPointerType does.
+    if (!RetTy)
+      RetTy = I8PtrTy;
+  }
+  for (const auto &A : F.args()) {
+    Type *ArgTy = A.getType();
+    if (ArgTy->isOpaquePointerTy()) {
+      TypedPointerType *NewTy = classifyPointerType(&A);
+      Map[&A] = NewTy;
+      ArgTy = NewTy;
+      HasOpaqueTy = true;
+    }
+    NewArgs.push_back(ArgTy);
+  }
+  if (!HasOpaqueTy)
+    return;
+  Map[&F] = FunctionType::get(RetTy, NewArgs, F.isVarArg());
+}
+} // anonymous namespace
+
+PointerTypeMap PointerTypeAnalysis::run(const Module &M) {
+  PointerTypeMap Map;
+  for (auto &G : M.globals()) {
+    if (G.getType()->isOpaquePointerTy())
+      Map[&G] = classifyPointerType(&G);
+  }
+  for (auto &F : M) {
+    handleFunction(F, Map);
+
+    for (auto &B : F) {
+      for (auto &I : B) {
+        if (I.getType()->isOpaquePointerTy())
+          Map[&I] = classifyPointerType(&I);
+      }
+    }
+  }
+  return Map;
+}
diff --git a/llvm/unittests/Target/DirectX/CMakeLists.txt b/llvm/unittests/Target/DirectX/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Target/DirectX/CMakeLists.txt
@@ -0,0 +1,15 @@
+include_directories(
+  ${LLVM_MAIN_SRC_DIR}/lib/Target/DirectX
+  ${LLVM_BINARY_DIR}/lib/Target/DirectX
+  )
+
+set(LLVM_LINK_COMPONENTS
+  AsmParser
+  Core
+  DirectXCodeGen
+  Support
+)
+
+add_llvm_target_unittest(DirectXTests
+  PointerTypeAnalysisTests.cpp
+  )
diff --git a/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp b/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
@@ -0,0 +1,208 @@
+//===- llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILPointerType.h"
+#include "PointerTypeAnalysis.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using namespace llvm::dxil;
+
+TEST(PointerTypeAnalysis, DigressToi8) {
+  StringRef Assembly = R"(
+    define i64 @test(ptr %p) {
+      store i32 0, ptr %p
+      %v = load i64, ptr %p
+      ret i64 %v
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 2u);
+  Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
+  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);
+
+  for (auto &Entry : Map) {
+    if (isa<Function>(Entry.first))
+      EXPECT_EQ(Entry.second, FnTy);
+    else if (isa<Argument>(Entry.first))
+      EXPECT_EQ(Entry.second, I8Ptr);
+    else
+      FAIL();
+  }
+}
+
+TEST(PointerTypeAnalysis, DiscoverStore) {
+  StringRef Assembly = R"(
+    define i32 @test(ptr %p) {
+      store i32 0, ptr %p
+      ret i32 0
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 2u);
+  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
+  Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);
+
+  for (auto &Entry : Map) {
+    if (isa<Function>(Entry.first))
+      EXPECT_EQ(Entry.second, FnTy);
+    else if (isa<Argument>(Entry.first))
+      EXPECT_EQ(Entry.second, I32Ptr);
+    else
+      FAIL();
+  }
+}
+
+TEST(PointerTypeAnalysis, DiscoverLoad) {
+  StringRef Assembly = R"(
+    define i32 @test(ptr %p) {
+      %v = load i32, ptr %p
+      ret i32 %v
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 2u);
+  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
+  Type *FnTy = FunctionType::get(Type::getInt32Ty(Context), {I32Ptr}, false);
+  for (auto &Entry : Map) {
+    if (isa<Function>(Entry.first))
+      EXPECT_EQ(Entry.second, FnTy);
+    else if (isa<Argument>(Entry.first))
+      EXPECT_EQ(Entry.second, I32Ptr);
+    else
+      FAIL();
+  }
+}
+
+TEST(PointerTypeAnalysis, DiscoverGEP) {
+  StringRef Assembly = R"(
+    define ptr @test(ptr %p) {
+      %p2 = getelementptr i64, ptr %p, i64 1
+      ret ptr %p2
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 3u);
+
+  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
+  Type *FnTy = FunctionType::get(I64Ptr, {I64Ptr}, false);
+  for (auto &Entry : Map) {
+    if (isa<Function>(Entry.first))
+      EXPECT_EQ(Entry.second, FnTy);
+    else if (isa<Argument>(Entry.first))
+      EXPECT_EQ(Entry.second, I64Ptr);
+    else if (isa<GetElementPtrInst>(Entry.first))
+      EXPECT_EQ(Entry.second, I64Ptr);
+    else
+      FAIL();
+  }
+}
+
+TEST(PointerTypeAnalysis, TraceIndirect) {
+  StringRef Assembly = R"(
+    define i64 @test(ptr %p) {
+      %p2 = load ptr, ptr %p
+      %v = load i64, ptr %p2
+      ret i64 %v
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 3u);
+
+  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
+  Type *I64PtrPtr = TypedPointerType::get(I64Ptr, 0);
+  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I64PtrPtr}, false);
+
+  for (auto &Entry : Map) {
+    if (isa<Function>(Entry.first))
+      EXPECT_EQ(Entry.second, FnTy);
+    else if (isa<Argument>(Entry.first))
+      EXPECT_EQ(Entry.second, I64PtrPtr);
+    else if (isa<LoadInst>(Entry.first))
+      EXPECT_EQ(Entry.second, I64Ptr);
+    else
+      FAIL();
+  }
+}
+
+TEST(PointerTypeAnalysis, WithNoOpCasts) {
+  StringRef Assembly = R"(
+    define i64 @test(ptr %p) {
+      %1 = bitcast ptr %p to ptr
+      %2 = bitcast ptr %p to ptr
+      store i32 0, ptr %1, align 4
+      %3 = load i64, ptr %2, align 8
+      ret i64 %3
+    }
+  )";
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 4u);
+
+  Type *I8Ptr = TypedPointerType::get(Type::getInt8Ty(Context), 0);
+  Type *I32Ptr = TypedPointerType::get(Type::getInt32Ty(Context), 0);
+  Type *I64Ptr = TypedPointerType::get(Type::getInt64Ty(Context), 0);
+  Type *FnTy = FunctionType::get(Type::getInt64Ty(Context), {I8Ptr}, false);
+
+  for (auto &Entry : Map) {
+    if (isa<Function>(Entry.first))
+      EXPECT_EQ(Entry.second, FnTy);
+    else if (isa<Argument>(Entry.first))
+      EXPECT_EQ(Entry.second, I8Ptr);
+    else if (isa<BitCastInst>(Entry.first)) {
+      const User *U = *(Entry.first->user_begin());
+      if (isa<LoadInst>(U))
+        EXPECT_EQ(Entry.second, I64Ptr);
+      else if (isa<StoreInst>(U))
+        EXPECT_EQ(Entry.second, I32Ptr);
+      else
+        FAIL();
+    } else
+      FAIL();
+  }
+}