diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -641,7 +641,10 @@
 
   /// This constructs a pointer to an object of the specified type in a numbered
   /// address space.
-  static PointerType *get(Type *ElementType, unsigned AddressSpace);
+  /// If \p ForceTyped is true, a typed pointer is returned even when the
+  /// context has opaque pointers enabled.
+  static PointerType *get(Type *ElementType, unsigned AddressSpace,
+                          bool ForceTyped = false);
   /// This constructs an opaque pointer to an object in a numbered address
   /// space.
   static PointerType *get(LLVMContext &C, unsigned AddressSpace);
diff --git a/llvm/lib/IR/Type.cpp b/llvm/lib/IR/Type.cpp
--- a/llvm/lib/IR/Type.cpp
+++ b/llvm/lib/IR/Type.cpp
@@ -724,14 +724,16 @@
 //                       PointerType Implementation
 //===----------------------------------------------------------------------===//
 
-PointerType *PointerType::get(Type *EltTy, unsigned AddressSpace) {
+PointerType *PointerType::get(Type *EltTy, unsigned AddressSpace,
+                              bool ForceTyped) {
   assert(EltTy && "Can't get a pointer to <null> type!");
   assert(isValidElementType(EltTy) && "Invalid type for pointer element!");
 
   LLVMContextImpl *CImpl = EltTy->getContext().pImpl;
 
-  // Automatically convert typed pointers to opaque pointers.
-  if (CImpl->getOpaquePointers())
+  // Automatically convert typed pointers to opaque pointers, unless the
+  // caller explicitly asks for a typed pointer via ForceTyped.
+  if (CImpl->getOpaquePointers() && !ForceTyped)
     return get(EltTy->getContext(), AddressSpace);
 
   // Since AddressSpace #0 is the common case, we special case it.
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -11,6 +11,7 @@
   DirectXTargetMachine.cpp
   DXILBitcodeWriter.cpp
   DXILPrepare.cpp
+  PointerTypeAnalysis.cpp
 
   LINK_COMPONENTS
   Bitwriter
diff --git a/llvm/lib/Target/DirectX/PointerTypeAnalysis.h b/llvm/lib/Target/DirectX/PointerTypeAnalysis.h
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/PointerTypeAnalysis.h
@@ -0,0 +1,38 @@
+//===- Target/DirectX/PointerTypeAnalysis.h - PointerType analysis --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Analysis pass to assign types to opaque pointers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TARGET_DIRECTX_POINTERTYPEANALYSIS_H
+#define LLVM_TARGET_DIRECTX_POINTERTYPEANALYSIS_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+// Maps each opaque pointer value to the typed pointer type inferred for it.
+using PointerTypeMap = DenseMap<const Value *, Type *>;
+
+/// An analysis to compute the \c PointerTypes for pointers in a \c Module.
+/// Since this analysis is only run during codegen and the new pass manager
+/// doesn't support codegen passes, this is written as a function in a
+/// namespace. It is very simple to transform it into a proper analysis pass.
+/// This code relies on typed pointers existing as LLVM types, but could be
+/// migrated to a custom Type if PointerType loses typed support.
+namespace PointerTypeAnalysis {
+
+/// Compute the \c PointerTypeMap for the module \c M.
+PointerTypeMap run(const Module &M);
+} // namespace PointerTypeAnalysis
+
+} // namespace llvm
+
+#endif // LLVM_TARGET_DIRECTX_POINTERTYPEANALYSIS_H
diff --git a/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp b/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Target/DirectX/PointerTypeAnalysis.cpp
@@ -0,0 +1,123 @@
+//===- Target/DirectX/PointerTypeAnalysis.cpp - PointerType analysis ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Analysis pass to assign types to opaque pointers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PointerTypeAnalysis.h"
+#include "llvm/IR/Instructions.h"
+
+using namespace llvm;
+
+namespace {
+
+Type *classifyPointerType(const Value *V);
+
+// Returns a typed pointer to Ty in address space AS, even when the context
+// uses opaque pointers.
+Type *getTypedPointer(Type *Ty, unsigned AS) {
+  return PointerType::get(Ty, AS, /*ForceTyped=*/true);
+}
+
+// Returns Ty, except that an opaque pointer is mapped to an i8 pointer in the
+// same address space because its own pointee is not resolved here.
+Type *getConservativeElementType(Type *Ty) {
+  if (!Ty->isOpaquePointerTy())
+    return Ty;
+  return getTypedPointer(Type::getInt8Ty(Ty->getContext()),
+                         Ty->getPointerAddressSpace());
+}
+
+// Returns the typed pointer type that use U implies for the pointer it uses,
+// or nullptr when U says nothing about that pointer's pointee.
+Type *classifyUse(const Use &U) {
+  unsigned AS = U.get()->getType()->getPointerAddressSpace();
+  const User *Usr = U.getUser();
+  if (const auto *LI = dyn_cast<LoadInst>(Usr)) {
+    // A loaded pointer is typed by how the loaded value is used in turn.
+    if (LI->getType()->isOpaquePointerTy())
+      return getTypedPointer(classifyPointerType(LI), AS);
+    return getTypedPointer(LI->getType(), AS);
+  }
+  if (const auto *SI = dyn_cast<StoreInst>(Usr)) {
+    // Being the stored value says nothing about what the pointer points to.
+    if (U.getOperandNo() != StoreInst::getPointerOperandIndex())
+      return nullptr;
+    // A stored pointer is not followed: it may lead back to the address
+    // (e.g. a pointer stored through itself) and recurse forever.
+    return getTypedPointer(
+        getConservativeElementType(SI->getValueOperand()->getType()), AS);
+  }
+  if (const auto *GEP = dyn_cast<GetElementPtrInst>(Usr)) {
+    if (U.getOperandNo() != GetElementPtrInst::getPointerOperandIndex())
+      return nullptr;
+    Type *SrcTy = GEP->getSourceElementType();
+    // A GEP over pointer elements takes a single index, so a scalar result
+    // has exactly the same pointer type as its base.
+    if (SrcTy->isOpaquePointerTy() && GEP->getType()->isOpaquePointerTy())
+      return classifyPointerType(GEP);
+    return getTypedPointer(getConservativeElementType(SrcTy), AS);
+  }
+  return nullptr;
+}
+
+// Returns the typed pointer type inferred for the opaque pointer V.
+//
+// The pointee is seeded from V's own definition (the result element type of a
+// GEP or the allocated type of an alloca) and must agree with every load,
+// store and GEP that dereferences V. Conflicting evidence, or no evidence at
+// all, yields an i8 pointer.
+Type *classifyPointerType(const Value *V) {
+  assert(V->getType()->isOpaquePointerTy() &&
+         "classifyPointerType called with non-opaque pointer");
+  unsigned AS = V->getType()->getPointerAddressSpace();
+  Type *Result = nullptr;
+  if (const auto *GEP = dyn_cast<GetElementPtrInst>(V)) {
+    if (!GEP->getResultElementType()->isOpaquePointerTy())
+      Result = getTypedPointer(GEP->getResultElementType(), AS);
+  } else if (const auto *AI = dyn_cast<AllocaInst>(V)) {
+    // An alloca never uses another pointer (its only operand is the array
+    // size), so its pointee can only come from its own definition.
+    if (!AI->getAllocatedType()->isOpaquePointerTy())
+      Result = getTypedPointer(AI->getAllocatedType(), AS);
+  }
+  Type *I8Ptr = getTypedPointer(Type::getInt8Ty(V->getContext()), AS);
+  for (const Use &U : V->uses()) {
+    Type *UseTy = classifyUse(U);
+    if (!UseTy)
+      continue;
+    if (!Result)
+      Result = UseTy;
+    else if (Result != UseTy)
+      return I8Ptr;
+  }
+  return Result ? Result : I8Ptr;
+}
+} // anonymous namespace
+
+PointerTypeMap PointerTypeAnalysis::run(const Module &M) {
+  PointerTypeMap Map;
+  for (auto &G : M.globals()) {
+    if (G.getType()->isOpaquePointerTy())
+      Map[&G] = classifyPointerType(&G);
+  }
+  for (auto &F : M) {
+    for (auto &A : F.args()) {
+      if (A.getType()->isOpaquePointerTy())
+        Map[&A] = classifyPointerType(&A);
+    }
+    for (auto &B : F) {
+      for (auto &I : B) {
+        if (I.getType()->isOpaquePointerTy())
+          Map[&I] = classifyPointerType(&I);
+      }
+    }
+  }
+  return Map;
+}
diff --git a/llvm/unittests/Target/DirectX/CMakeLists.txt b/llvm/unittests/Target/DirectX/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Target/DirectX/CMakeLists.txt
@@ -0,0 +1,15 @@
+include_directories(
+  ${LLVM_MAIN_SRC_DIR}/lib/Target/DirectX
+  ${LLVM_BINARY_DIR}/lib/Target/DirectX
+  )
+
+set(LLVM_LINK_COMPONENTS
+  AsmParser
+  Core
+  DirectXCodeGen
+  Support
+)
+
+add_llvm_target_unittest(DirectXTests
+  PointerTypeAnalysisTests.cpp
+  )
diff --git a/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp b/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp
@@ -0,0 +1,161 @@
+//===- llvm/unittests/Target/DirectX/PointerTypeAnalysisTests.cpp ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PointerTypeAnalysis.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/SourceMgr.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+using ::testing::UnorderedElementsAre;
+
+namespace {
+// Collects the types inferred in Map. DenseMap iteration order is unspecified,
+// so results must be compared without regard to order.
+SmallVector<Type *, 4> getTypes(const PointerTypeMap &Map) {
+  SmallVector<Type *, 4> Types;
+  for (const auto &El : Map)
+    Types.push_back(El.second);
+  return Types;
+}
+} // namespace
+
+TEST(PointerTypeAnalysis, DigressToi8) {
+  StringRef Assembly = R"(
+    define i64 @test(ptr %p) {
+      store i32 0, ptr %p
+      %v = load i64, ptr %p
+      ret i64 %v
+    }
+  )";
+
+  // Parse the IR.
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 1u);
+  EXPECT_EQ(Map.begin()->second,
+            PointerType::get(Type::getInt8Ty(Context), 0, true));
+}
+
+TEST(PointerTypeAnalysis, DiscoverStore) {
+  StringRef Assembly = R"(
+    define i32 @test(ptr %p) {
+      store i32 0, ptr %p
+      ret i32 0
+    }
+  )";
+
+  // Parse the IR.
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 1u);
+  EXPECT_EQ(Map.begin()->second,
+            PointerType::get(Type::getInt32Ty(Context), 0, true));
+}
+
+TEST(PointerTypeAnalysis, DiscoverLoad) {
+  StringRef Assembly = R"(
+    define i32 @test(ptr %p) {
+      %v = load i32, ptr %p
+      ret i32 %v
+    }
+  )";
+
+  // Parse the IR.
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 1u);
+  EXPECT_EQ(Map.begin()->second,
+            PointerType::get(Type::getInt32Ty(Context), 0, true));
+}
+
+TEST(PointerTypeAnalysis, DiscoverGEP) {
+  StringRef Assembly = R"(
+    define ptr @test(ptr %p) {
+      %p2 = getelementptr i64, ptr %p, i64 1
+      ret ptr %p2
+    }
+  )";
+
+  // Parse the IR.
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 2u);
+
+  Type *I64Ptr = PointerType::get(Type::getInt64Ty(Context), 0, true);
+  EXPECT_THAT(getTypes(Map), UnorderedElementsAre(I64Ptr, I64Ptr));
+}
+
+TEST(PointerTypeAnalysis, TraceIndirect) {
+  StringRef Assembly = R"(
+    define i64 @test(ptr %p) {
+      %p2 = load ptr, ptr %p
+      %v = load i64, ptr %p2
+      ret i64 %v
+    }
+  )";
+
+  // Parse the IR.
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 2u);
+
+  Type *I64Ptr = PointerType::get(Type::getInt64Ty(Context), 0, true);
+  Type *I64PtrPtr = PointerType::get(I64Ptr, 0, true);
+  EXPECT_THAT(getTypes(Map), UnorderedElementsAre(I64Ptr, I64PtrPtr));
+}
+
+TEST(PointerTypeAnalysis, WithNoOpCasts) {
+  StringRef Assembly = R"(
+    define i64 @test(ptr %p) {
+      %1 = bitcast ptr %p to ptr
+      %2 = bitcast ptr %p to ptr
+      store i32 0, ptr %1, align 4
+      %3 = load i64, ptr %2, align 8
+      ret i64 %3
+    }
+  )";
+
+  // Parse the IR.
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(Assembly, Error, Context);
+  ASSERT_TRUE(M) << "Bad assembly?";
+
+  PointerTypeMap Map = PointerTypeAnalysis::run(*M);
+  ASSERT_EQ(Map.size(), 3u);
+
+  // Casts are not looked through, so %p itself has no typed uses and is i8*.
+  Type *I8Ptr = PointerType::get(Type::getInt8Ty(Context), 0, true);
+  Type *I32Ptr = PointerType::get(Type::getInt32Ty(Context), 0, true);
+  Type *I64Ptr = PointerType::get(Type::getInt64Ty(Context), 0, true);
+  EXPECT_THAT(getTypes(Map), UnorderedElementsAre(I8Ptr, I32Ptr, I64Ptr));
+}