Index: parallel-libs/trunk/streamexecutor/include/streamexecutor/Device.h =================================================================== --- parallel-libs/trunk/streamexecutor/include/streamexecutor/Device.h +++ parallel-libs/trunk/streamexecutor/include/streamexecutor/Device.h @@ -15,13 +15,14 @@ #ifndef STREAMEXECUTOR_DEVICE_H #define STREAMEXECUTOR_DEVICE_H +#include + #include "streamexecutor/KernelSpec.h" #include "streamexecutor/PlatformInterfaces.h" #include "streamexecutor/Utils/Error.h" namespace streamexecutor { -class KernelInterface; class Stream; class Device { @@ -29,11 +30,24 @@ explicit Device(PlatformDevice *PDevice); virtual ~Device(); - /// Gets the kernel implementation for the underlying platform. - virtual Expected> - getKernelImplementation(const MultiKernelLoaderSpec &Spec) { - // TODO(jhen): Implement this. - return nullptr; + /// Creates a kernel object for this device. + /// + /// If the return value is not an error, the returned pointer will never be + /// null. + /// + /// See \ref CompilerGeneratedKernelExample "Kernel.h" for an example of how + /// this method is used. + template + Expected::value, KernelT>::type>> + createKernel(const MultiKernelLoaderSpec &Spec) { + Expected> MaybeKernelHandle = + PDevice->createKernel(Spec); + if (!MaybeKernelHandle) { + return MaybeKernelHandle.takeError(); + } + return llvm::make_unique(Spec.getKernelName(), + std::move(*MaybeKernelHandle)); } Expected> createStream(); Index: parallel-libs/trunk/streamexecutor/include/streamexecutor/Kernel.h =================================================================== --- parallel-libs/trunk/streamexecutor/include/streamexecutor/Kernel.h +++ parallel-libs/trunk/streamexecutor/include/streamexecutor/Kernel.h @@ -11,62 +11,64 @@ /// Types to represent device kernels (code compiled to run on GPU or other /// accelerator). /// -/// The TypedKernel class is used to provide type safety to the user API's -/// launch functions, and the KernelBase class is used like a void* function -/// pointer to perform type-unsafe operations inside StreamExecutor. -/// -/// With the kernel parameter types recorded in the TypedKernel template -/// parameters, type-safe kernel launch functions can be written with signatures -/// like the following: +/// With the kernel parameter types recorded in the Kernel template parameters, +/// type-safe kernel launch functions can be written with signatures like the +/// following: /// \code /// template /// void Launch( -/// const TypedKernel &Kernel, ParamterTs... Arguments); +/// const Kernel &Kernel, ParamterTs... Arguments); /// \endcode /// and the compiler will check that the user passes in arguments with types /// matching the corresponding kernel parameters. /// -/// A problem is that a TypedKernel template specialization with the right -/// parameter types must be passed as the first argument to the Launch function, -/// and it's just as hard to get the types right in that template specialization -/// as it is to get them right for the kernel arguments. +/// A problem is that a Kernel template specialization with the right parameter +/// types must be passed as the first argument to the Launch function, and it's +/// just as hard to get the types right in that template specialization as it is +/// to get them right for the kernel arguments. /// /// With this problem in mind, it is not recommended for users to specialize the -/// TypedKernel template class themselves, but instead to let the compiler do it -/// for them. When the compiler encounters a device kernel function, it can -/// create a TypedKernel template specialization in the host code that has the -/// right parameter types for that kernel and which has a type name based on the -/// name of the kernel function. +/// Kernel template class themselves, but instead to let the compiler do it for +/// them. When the compiler encounters a device kernel function, it can create a +/// Kernel template specialization in the host code that has the right parameter +/// types for that kernel and which has a type name based on the name of the +/// kernel function. /// +/// \anchor CompilerGeneratedKernelExample /// For example, if a CUDA device kernel function with the following signature /// has been defined: /// \code -/// void Saxpy(float *A, float *X, float *Y); +/// void Saxpy(float A, float *X, float *Y); /// \endcode /// the compiler can insert the following declaration in the host code: /// \code /// namespace compiler_cuda_namespace { +/// namespace se = streamexecutor; /// using SaxpyKernel = -/// streamexecutor::TypedKernel; +/// se::Kernel< +/// float, +/// se::GlobalDeviceMemory, +/// se::GlobalDeviceMemory>; /// } // namespace compiler_cuda_namespace /// \endcode /// and then the user can launch the kernel by calling the StreamExecutor launch /// function as follows: /// \code /// namespace ccn = compiler_cuda_namespace; +/// using KernelPtr = std::unique_ptr; /// // Assumes Device is a pointer to the Device on which to launch the /// // kernel. /// // /// // See KernelSpec.h for details on how the compiler can create a /// // MultiKernelLoaderSpec instance like SaxpyKernelLoaderSpec below. -/// Expected MaybeKernel = -/// ccn::SaxpyKernel::create(Device, ccn::SaxpyKernelLoaderSpec); +/// Expected MaybeKernel = +/// Device->createKernel(ccn::SaxpyKernelLoaderSpec); /// if (!MaybeKernel) { /* Handle error */ } -/// ccn::SaxpyKernel SaxpyKernel = *MaybeKernel; -/// Launch(SaxpyKernel, A, X, Y); +/// KernelPtr SaxpyKernel = std::move(*MaybeKernel); +/// Launch(*SaxpyKernel, A, X, Y); /// \endcode /// -/// With the compiler's help in specializing TypedKernel for each device kernel +/// With the compiler's help in specializing Kernel for each device kernel /// function (and generating a MultiKernelLoaderSpec instance for each kernel), /// the user can safely launch the device kernel from the host and get an error /// message at compile time if the argument types don't match the kernel @@ -84,73 +86,37 @@ namespace streamexecutor { -class Device; -class KernelInterface; +class PlatformKernelHandle; -/// The base class for device kernel functions. +/// The base class for all kernel types. /// -/// This class has no information about the types of the parameters taken by the -/// kernel, so it is analogous to a void* pointer to a device function. -/// -/// See the TypedKernel class below for the subclass which does have information -/// about parameter types. +/// Stores the name of the kernel in both mangled and demangled forms. class KernelBase { public: - KernelBase(KernelBase &&) = default; - KernelBase &operator=(KernelBase &&) = default; - ~KernelBase(); - - /// Creates a kernel object from a Device and a MultiKernelLoaderSpec. - /// - /// The Device knows which platform it belongs to and the - /// MultiKernelLoaderSpec knows how to find the kernel code for different - /// platforms, so the combined information is enough to get the kernel code - /// for the appropriate platform. - static Expected create(Device *Dev, - const MultiKernelLoaderSpec &Spec); + KernelBase(llvm::StringRef Name); const std::string &getName() const { return Name; } const std::string &getDemangledName() const { return DemangledName; } - /// Gets a pointer to the platform-specific implementation of this kernel. - KernelInterface *getImplementation() { return Implementation.get(); } - private: - KernelBase(Device *Dev, const std::string &Name, - const std::string &DemangledName, - std::unique_ptr Implementation); - - Device *TheDevice; std::string Name; std::string DemangledName; - std::unique_ptr Implementation; - - KernelBase(const KernelBase &) = delete; - KernelBase &operator=(const KernelBase &) = delete; }; -/// A device kernel function with specified parameter types. -template class TypedKernel : public KernelBase { +/// A StreamExecutor kernel. +/// +/// The template parameters are the types of the parameters to the kernel +/// function. +template class Kernel : public KernelBase { public: - TypedKernel(TypedKernel &&) = default; - TypedKernel &operator=(TypedKernel &&) = default; + Kernel(llvm::StringRef Name, std::unique_ptr PHandle) + : KernelBase(Name), PHandle(std::move(PHandle)) {} - /// Parameters here have the same meaning as in KernelBase::create. - static Expected create(Device *Dev, - const MultiKernelLoaderSpec &Spec) { - auto MaybeBase = KernelBase::create(Dev, Spec); - if (!MaybeBase) { - return MaybeBase.takeError(); - } - TypedKernel Instance(std::move(*MaybeBase)); - return std::move(Instance); - } + /// Gets the underlying platform-specific handle for this kernel. + PlatformKernelHandle *getPlatformHandle() const { return PHandle.get(); } private: - TypedKernel(KernelBase &&Base) : KernelBase(std::move(Base)) {} - - TypedKernel(const TypedKernel &) = delete; - TypedKernel &operator=(const TypedKernel &) = delete; + std::unique_ptr PHandle; }; } // namespace streamexecutor Index: parallel-libs/trunk/streamexecutor/include/streamexecutor/PlatformInterfaces.h =================================================================== --- parallel-libs/trunk/streamexecutor/include/streamexecutor/PlatformInterfaces.h +++ parallel-libs/trunk/streamexecutor/include/streamexecutor/PlatformInterfaces.h @@ -33,9 +33,17 @@ class PlatformDevice; -/// Methods supported by device kernel function objects on all platforms. -class KernelInterface { - // TODO(jhen): Add methods. +/// Platform-specific kernel handle. +class PlatformKernelHandle { +public: + explicit PlatformKernelHandle(PlatformDevice *PDevice) : PDevice(PDevice) {} + + virtual ~PlatformKernelHandle(); + + PlatformDevice *getDevice() { return PDevice; } + +private: + PlatformDevice *PDevice; }; /// Platform-specific stream handle. @@ -64,12 +72,20 @@ virtual std::string getName() const = 0; + /// Creates a platform-specific kernel. + virtual Expected> + createKernel(const MultiKernelLoaderSpec &Spec) { + return make_error("createKernel not implemented for platform " + getName()); + } + /// Creates a platform-specific stream. - virtual Expected> createStream() = 0; + virtual Expected> createStream() { + return make_error("createStream not implemented for platform " + getName()); + } /// Launches a kernel on the given stream. virtual Error launch(PlatformStreamHandle *S, BlockDimensions BlockSize, - GridDimensions GridSize, const KernelBase &Kernel, + GridDimensions GridSize, PlatformKernelHandle *K, const PackedKernelArgumentArrayBase &ArgumentArray) { return make_error("launch not implemented for platform " + getName()); } Index: parallel-libs/trunk/streamexecutor/include/streamexecutor/Stream.h =================================================================== --- parallel-libs/trunk/streamexecutor/include/streamexecutor/Stream.h +++ parallel-libs/trunk/streamexecutor/include/streamexecutor/Stream.h @@ -86,15 +86,15 @@ /// These arguments can be device memory types like GlobalDeviceMemory and /// SharedDeviceMemory, or they can be primitive types such as int. The /// allowable argument types are determined by the template parameters to the - /// TypedKernel argument. + /// Kernel argument. template Stream &thenLaunch(BlockDimensions BlockSize, GridDimensions GridSize, - const TypedKernel &Kernel, + const Kernel &K, const ParameterTs &... Arguments) { auto ArgumentArray = make_kernel_argument_pack(Arguments...); setError(PDevice->launch(ThePlatformStream.get(), BlockSize, GridSize, - Kernel, ArgumentArray)); + K.getPlatformHandle(), ArgumentArray)); return *this; } Index: parallel-libs/trunk/streamexecutor/lib/Kernel.cpp =================================================================== --- parallel-libs/trunk/streamexecutor/lib/Kernel.cpp +++ parallel-libs/trunk/streamexecutor/lib/Kernel.cpp @@ -20,26 +20,8 @@ namespace streamexecutor { -KernelBase::KernelBase(Device *Dev, const std::string &Name, - const std::string &DemangledName, - std::unique_ptr Implementation) - : TheDevice(Dev), Name(Name), DemangledName(DemangledName), - Implementation(std::move(Implementation)) {} - -KernelBase::~KernelBase() = default; - -Expected KernelBase::create(Device *Dev, - const MultiKernelLoaderSpec &Spec) { - auto MaybeImplementation = Dev->getKernelImplementation(Spec); - if (!MaybeImplementation) { - return MaybeImplementation.takeError(); - } - std::string Name = Spec.getKernelName(); - std::string DemangledName = - llvm::symbolize::LLVMSymbolizer::DemangleName(Name, nullptr); - KernelBase Instance(Dev, Name, DemangledName, - std::move(*MaybeImplementation)); - return std::move(Instance); -} +KernelBase::KernelBase(llvm::StringRef Name) + : Name(Name), DemangledName(llvm::symbolize::LLVMSymbolizer::DemangleName( + Name, nullptr)) {} } // namespace streamexecutor Index: parallel-libs/trunk/streamexecutor/lib/unittests/CMakeLists.txt =================================================================== --- parallel-libs/trunk/streamexecutor/lib/unittests/CMakeLists.txt +++ parallel-libs/trunk/streamexecutor/lib/unittests/CMakeLists.txt @@ -9,16 +9,6 @@ add_test(DeviceTest device_test) add_executable( - kernel_test - KernelTest.cpp) -target_link_libraries( - kernel_test - streamexecutor - ${GTEST_BOTH_LIBRARIES} - ${CMAKE_THREAD_LIBS_INIT}) -add_test(KernelTest kernel_test) - -add_executable( kernel_spec_test KernelSpecTest.cpp) target_link_libraries( Index: parallel-libs/trunk/streamexecutor/lib/unittests/KernelTest.cpp =================================================================== --- parallel-libs/trunk/streamexecutor/lib/unittests/KernelTest.cpp +++ parallel-libs/trunk/streamexecutor/lib/unittests/KernelTest.cpp @@ -1,93 +0,0 @@ -//===-- KernelTest.cpp - Tests for Kernel objects -------------------------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -/// -/// \file -/// This file contains the unit tests for the code in Kernel. -/// -//===----------------------------------------------------------------------===// - -#include - -#include "streamexecutor/Device.h" -#include "streamexecutor/Kernel.h" -#include "streamexecutor/KernelSpec.h" -#include "streamexecutor/PlatformInterfaces.h" - -#include "llvm/ADT/STLExtras.h" - -#include "gtest/gtest.h" - -namespace { - -namespace se = ::streamexecutor; - -// A Device that returns a dummy KernelInterface. -// -// During construction it creates a unique_ptr to a dummy KernelInterface and it -// also stores a separate copy of the raw pointer that is stored by that -// unique_ptr. -// -// The expectation is that the code being tested will call the -// getKernelImplementation method and will thereby take ownership of the -// unique_ptr, but the copy of the raw pointer will stay behind in this mock -// object. The raw pointer copy can then be used to identify the unique_ptr in -// its new location (by comparing the raw pointer with unique_ptr::get), to -// verify that the unique_ptr ended up where it was supposed to be. -class MockDevice : public se::Device { -public: - MockDevice() - : se::Device(nullptr), Unique(llvm::make_unique()), - Raw(Unique.get()) {} - - // Moves the unique pointer into the returned se::Expected instance. - // - // Asserts that it is not called again after the unique pointer has been moved - // out. - se::Expected> - getKernelImplementation(const se::MultiKernelLoaderSpec &) override { - assert(Unique && "MockDevice getKernelImplementation should not be " - "called more than once"); - return std::move(Unique); - } - - // Gets the copy of the raw pointer from the original unique pointer. - const se::KernelInterface *getRaw() const { return Raw; } - -private: - std::unique_ptr Unique; - const se::KernelInterface *Raw; -}; - -// Test fixture class for typed tests for KernelBase.getImplementation. -// -// The only purpose of this class is to provide a name that types can be bound -// to in the gtest infrastructure. -template class GetImplementationTest : public ::testing::Test {}; - -// Types used with the GetImplementationTest fixture class. -typedef ::testing::Types, - se::TypedKernel> - GetImplementationTypes; - -TYPED_TEST_CASE(GetImplementationTest, GetImplementationTypes); - -// Tests that the kernel create functions properly fetch the implementation -// pointers for the kernel objects they construct from the passed-in -// Device objects. -TYPED_TEST(GetImplementationTest, SetImplementationDuringCreate) { - se::MultiKernelLoaderSpec Spec; - MockDevice Dev; - - auto MaybeKernel = TypeParam::create(&Dev, Spec); - EXPECT_TRUE(static_cast(MaybeKernel)); - se::KernelInterface *Implementation = MaybeKernel->getImplementation(); - EXPECT_EQ(Dev.getRaw(), Implementation); -} - -} // namespace