diff --git a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
--- a/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
+++ b/mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
@@ -12,11 +12,9 @@
 //
 //===----------------------------------------------------------------------===//
 
 #include <cassert>
-#include <numeric>
 
 #include "mlir/ExecutionEngine/CRunnerUtils.h"
-#include "llvm/ADT/ArrayRef.h"
 
 #include "cuda.h"
 
@@ -160,26 +158,29 @@
   CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
 }
 
-// Allows to register a MemRef with the CUDA runtime. Helpful until we have
-// transfer functions implemented.
+/// Registers a memref with the CUDA runtime. `descriptor` is a pointer to a
+/// ranked memref descriptor struct of rank `rank` whose elements are
+/// `elementSizeBytes` bytes wide. Helpful until we have transfer functions
+/// implemented.
 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
 mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
                           int64_t elementSizeBytes) {
-
-  llvm::SmallVector<int64_t, 4> denseStrides(rank);
-  llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
-  llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
-
-  std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
-                   std::multiplies<int64_t>());
-  auto sizeBytes = denseStrides.front() * elementSizeBytes;
-
   // Only densely packed tensors are currently supported.
-  std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
-              denseStrides.end());
-  denseStrides.back() = 1;
-  assert(strides == llvm::makeArrayRef(denseStrides));
+  // The descriptor stores `rank` sizes followed directly by `rank` strides, so
+  // the strides start at `sizes[rank]`.
+  int64_t *sizes = descriptor->sizes;
+  int64_t *strides = &sizes[rank];
+  (void)strides; // Only read by `assert`; avoids a warning under NDEBUG.
+  // Walk dimensions innermost-first. On entry to iteration `i`, `numElements`
+  // is the product of sizes[i+1..rank-1], i.e. the dense stride of dim `i`.
+  // After the loop it is the total element count (1 for a rank-0 memref).
+  int64_t numElements = 1;
+  for (int64_t i = rank - 1; i >= 0; --i) {
+    assert(strides[i] == numElements && "Mismatch in computed dense strides");
+    numElements *= sizes[i];
+  }
+  uint64_t sizeBytes = numElements * elementSizeBytes;
 
-  auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
+  auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
   mgpuMemHostRegister(ptr, sizeBytes);
 }