diff --git a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h --- a/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/CRunnerUtils.h @@ -33,12 +33,6 @@ #include -template -void dropFront(int64_t arr[N], int64_t *res) { - for (unsigned i = 1; i < N; ++i) - *(res + i - 1) = arr[i]; -} - //===----------------------------------------------------------------------===// // Codegen-compatible structures for Vector type. //===----------------------------------------------------------------------===// @@ -129,6 +123,10 @@ res.basePtr = basePtr; res.data = data; res.offset = offset + idx * strides[0]; + auto dropFront = [](const int64_t *arr, int64_t *res) { + for (unsigned i = 1; i < N; ++i) + res[i - 1] = arr[i]; + }; dropFront(sizes, res.sizes); dropFront(strides, res.strides); return res; @@ -165,6 +163,39 @@ }; //===----------------------------------------------------------------------===// +// DynamicMemRefType type. +//===----------------------------------------------------------------------===// +// A reference to one of the StridedMemRef types. +template +class DynamicMemRefType { +public: + explicit DynamicMemRefType(const StridedMemRefType &mem_ref) + : rank(0), basePtr(mem_ref.basePtr), data(mem_ref.data), + offset(mem_ref.offset), sizes(nullptr), strides(nullptr) {} + template + explicit DynamicMemRefType(const StridedMemRefType &mem_ref) + : rank(N), basePtr(mem_ref.basePtr), data(mem_ref.data), + offset(mem_ref.offset), sizes(mem_ref.sizes), strides(mem_ref.strides) { + } + explicit DynamicMemRefType(const UnrankedMemRefType &mem_ref) + : rank(mem_ref.rank) { + auto *desc = static_cast *>(mem_ref.descriptor); + basePtr = desc->basePtr; + data = desc->data; + offset = desc->offset; + sizes = rank == 0 ? nullptr : desc->sizes; + strides = sizes + rank; + } + + int64_t rank; + T *basePtr; + T *data; + int64_t offset; + const int64_t *sizes; + const int64_t *strides; +}; + +//===----------------------------------------------------------------------===// // Small runtime support "lib" for vector.print lowering during codegen. //===----------------------------------------------------------------------===// extern "C" MLIR_CRUNNERUTILS_EXPORT void print_i1(bool b); diff --git a/mlir/include/mlir/ExecutionEngine/RunnerUtils.h b/mlir/include/mlir/ExecutionEngine/RunnerUtils.h --- a/mlir/include/mlir/ExecutionEngine/RunnerUtils.h +++ b/mlir/include/mlir/ExecutionEngine/RunnerUtils.h @@ -35,29 +35,35 @@ #include "mlir/ExecutionEngine/CRunnerUtils.h" -template -void printMemRefMetaData(StreamType &os, StridedMemRefType &V) { - static_assert(N > 0, "Expected N > 0"); - os << "Memref base@ = " << reinterpret_cast(V.data) << " rank = " << N - << " offset = " << V.offset << " sizes = [" << V.sizes[0]; - for (unsigned i = 1; i < N; ++i) - os << ", " << V.sizes[i]; - os << "] strides = [" << V.strides[0]; - for (unsigned i = 1; i < N; ++i) - os << ", " << V.strides[i]; +template +void printMemRefMetaData(StreamType &os, const DynamicMemRefType &V) { + os << "base@ = " << reinterpret_cast(V.data) << " rank = " << V.rank + << " offset = " << V.offset; + auto print = [&](const int64_t *ptr) { + if (V.rank == 0) + return; + os << ptr[0]; + for (int64_t i = 1; i < V.rank; ++i) + os << ", " << ptr[i]; + }; + os << " sizes = ["; + print(V.sizes); + os << "] strides = ["; + print(V.strides); os << "]"; } -template -void printMemRefMetaData(StreamType &os, StridedMemRefType &V) { - os << "Memref base@ = " << reinterpret_cast(V.data) << " rank = 0" - << " offset = " << V.offset; +template +void printMemRefMetaData(StreamType &os, StridedMemRefType &V) { + static_assert(N >= 0, "Expected N > 0"); + os << "MemRef "; + printMemRefMetaData(os, DynamicMemRefType(V)); } -template +template void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType &V) { - os << "Unranked Memref rank = " << V.rank << " " - << "descriptor@ = " << reinterpret_cast(V.descriptor) << "\n"; + os << "Unranked MemRef "; + printMemRefMetaData(os, DynamicMemRefType(V)); } //////////////////////////////////////////////////////////////////////////////// @@ -118,88 +124,92 @@ return os; } -template struct MemRefDataPrinter { - static void print(std::ostream &os, T *base, int64_t rank, int64_t offset, - int64_t *sizes, int64_t *strides); - static void printFirst(std::ostream &os, T *base, int64_t rank, - int64_t offset, int64_t *sizes, int64_t *strides); - static void printLast(std::ostream &os, T *base, int64_t rank, int64_t offset, - int64_t *sizes, int64_t *strides); -}; - -template struct MemRefDataPrinter { - static void print(std::ostream &os, T *base, int64_t rank, int64_t offset, - int64_t *sizes = nullptr, int64_t *strides = nullptr); +template +struct MemRefDataPrinter { + static void print(std::ostream &os, T *base, int64_t dim, int64_t rank, + int64_t offset, const int64_t *sizes, + const int64_t *strides); + static void printFirst(std::ostream &os, T *base, int64_t dim, int64_t rank, + int64_t offset, const int64_t *sizes, + const int64_t *strides); + static void printLast(std::ostream &os, T *base, int64_t dim, int64_t rank, + int64_t offset, const int64_t *sizes, + const int64_t *strides); }; -template -void MemRefDataPrinter::printFirst(std::ostream &os, T *base, - int64_t rank, int64_t offset, - int64_t *sizes, int64_t *strides) { +template +void MemRefDataPrinter::printFirst(std::ostream &os, T *base, int64_t dim, + int64_t rank, int64_t offset, + const int64_t *sizes, + const int64_t *strides) { os << "["; - MemRefDataPrinter::print(os, base, rank, offset, sizes + 1, - strides + 1); + print(os, base, dim - 1, rank, offset, sizes + 1, strides + 1); // If single element, close square bracket and return early. if (sizes[0] <= 1) { os << "]"; return; } os << ", "; - if (N > 1) + if (dim > 1) os << "\n"; } -template -void MemRefDataPrinter::print(std::ostream &os, T *base, int64_t rank, - int64_t offset, int64_t *sizes, - int64_t *strides) { - printFirst(os, base, rank, offset, sizes, strides); +template +void MemRefDataPrinter::print(std::ostream &os, T *base, int64_t dim, + int64_t rank, int64_t offset, + const int64_t *sizes, const int64_t *strides) { + if (dim == 0) { + os << base[offset]; + return; + } + printFirst(os, base, dim, rank, offset, sizes, strides); for (unsigned i = 1; i + 1 < sizes[0]; ++i) { - printSpace(os, rank - N + 1); - MemRefDataPrinter::print(os, base, rank, offset + i * strides[0], - sizes + 1, strides + 1); + printSpace(os, rank - dim + 1); + print(os, base, dim - 1, rank, offset + i * strides[0], sizes + 1, + strides + 1); os << ", "; - if (N > 1) + if (dim > 1) os << "\n"; } if (sizes[0] <= 1) return; - printLast(os, base, rank, offset, sizes, strides); + printLast(os, base, dim, rank, offset, sizes, strides); } -template -void MemRefDataPrinter::printLast(std::ostream &os, T *base, int64_t rank, - int64_t offset, int64_t *sizes, - int64_t *strides) { - printSpace(os, rank - N + 1); - MemRefDataPrinter::print(os, base, rank, - offset + (sizes[0] - 1) * (*strides), - sizes + 1, strides + 1); +template +void MemRefDataPrinter::printLast(std::ostream &os, T *base, int64_t dim, + int64_t rank, int64_t offset, + const int64_t *sizes, + const int64_t *strides) { + printSpace(os, rank - dim + 1); + print(os, base, dim - 1, rank, offset + (sizes[0] - 1) * (*strides), + sizes + 1, strides + 1); os << "]"; } template -void MemRefDataPrinter::print(std::ostream &os, T *base, int64_t rank, - int64_t offset, int64_t *sizes, - int64_t *strides) { - os << base[offset]; -} - -template void printMemRef(StridedMemRefType &M) { - static_assert(N > 0, "Expected N > 0"); +void printMemRef(const DynamicMemRefType &M) { printMemRefMetaData(std::cout, M); std::cout << " data = " << std::endl; - MemRefDataPrinter::print(std::cout, M.data, N, M.offset, M.sizes, - M.strides); + if (M.rank == 0) + std::cout << "["; + MemRefDataPrinter::print(std::cout, M.data, M.rank, M.rank, M.offset, + M.sizes, M.strides); + if (M.rank == 0) + std::cout << "]"; std::cout << std::endl; } -template void printMemRef(StridedMemRefType &M) { - printMemRefMetaData(std::cout, M); - std::cout << " data = " << std::endl; - std::cout << "["; - MemRefDataPrinter::print(std::cout, M.data, 0, M.offset); - std::cout << "]" << std::endl; +template +void printMemRef(StridedMemRefType &M) { + std::cout << "Memref "; + printMemRef(DynamicMemRefType(M)); +} + +template +void printMemRef(UnrankedMemRefType &M) { + std::cout << "Unranked Memref "; + printMemRef(DynamicMemRefType(M)); } } // namespace impl diff --git a/mlir/lib/ExecutionEngine/RunnerUtils.cpp b/mlir/lib/ExecutionEngine/RunnerUtils.cpp --- a/mlir/lib/ExecutionEngine/RunnerUtils.cpp +++ b/mlir/lib/ExecutionEngine/RunnerUtils.cpp @@ -24,57 +24,16 @@ impl::printMemRef(*M); } -#define MEMREF_CASE(TYPE, RANK) \ - case RANK: \ - impl::printMemRef(*(static_cast *>(ptr))); \ - break - extern "C" void _mlir_ciface_print_memref_i8(UnrankedMemRefType *M) { - printUnrankedMemRefMetaData(std::cout, *M); - int64_t rank = M->rank; - void *ptr = M->descriptor; - - switch (rank) { - MEMREF_CASE(int8_t, 0); - MEMREF_CASE(int8_t, 1); - MEMREF_CASE(int8_t, 2); - MEMREF_CASE(int8_t, 3); - MEMREF_CASE(int8_t, 4); - default: - assert(0 && "Unsupported rank to print"); - } + impl::printMemRef(*M); } extern "C" void _mlir_ciface_print_memref_i32(UnrankedMemRefType *M) { - printUnrankedMemRefMetaData(std::cout, *M); - int64_t rank = M->rank; - void *ptr = M->descriptor; - - switch (rank) { - MEMREF_CASE(int32_t, 0); - MEMREF_CASE(int32_t, 1); - MEMREF_CASE(int32_t, 2); - MEMREF_CASE(int32_t, 3); - MEMREF_CASE(int32_t, 4); - default: - assert(0 && "Unsupported rank to print"); - } + impl::printMemRef(*M); } extern "C" void _mlir_ciface_print_memref_f32(UnrankedMemRefType *M) { - printUnrankedMemRefMetaData(std::cout, *M); - int64_t rank = M->rank; - void *ptr = M->descriptor; - - switch (rank) { - MEMREF_CASE(float, 0); - MEMREF_CASE(float, 1); - MEMREF_CASE(float, 2); - MEMREF_CASE(float, 3); - MEMREF_CASE(float, 4); - default: - assert(0 && "Unsupported rank to print"); - } + impl::printMemRef(*M); } extern "C" void print_memref_i32(int64_t rank, void *ptr) { diff --git a/mlir/test/mlir-cpu-runner/unranked_memref.mlir b/mlir/test/mlir-cpu-runner/unranked_memref.mlir --- a/mlir/test/mlir-cpu-runner/unranked_memref.mlir +++ b/mlir/test/mlir-cpu-runner/unranked_memref.mlir @@ -1,25 +1,21 @@ // RUN: mlir-opt %s -convert-linalg-to-loops -convert-linalg-to-llvm -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%linalg_test_lib_dir/libmlir_runner_utils%shlibext,%linalg_test_lib_dir/libmlir_test_cblas%shlibext,%linalg_test_lib_dir/libmlir_test_cblas_interface%shlibext | FileCheck %s // CHECK: rank = 2 -// CHECK: rank = 2 // CHECK-SAME: sizes = [10, 3] // CHECK-SAME: strides = [3, 1] // CHECK-COUNT-10: [10, 10, 10] // // CHECK: rank = 2 -// CHECK: rank = 2 // CHECK-SAME: sizes = [10, 3] // CHECK-SAME: strides = [3, 1] // CHECK-COUNT-10: [5, 5, 5] // // CHECK: rank = 2 -// CHECK: rank = 2 // CHECK-SAME: sizes = [10, 3] // CHECK-SAME: strides = [3, 1] // CHECK-COUNT-10: [2, 2, 2] // // CHECK: rank = 0 -// CHECK: rank = 0 // 122 is ASCII for 'z'. // CHECK: [z] func @main() -> () { diff --git a/mlir/test/mlir-cpu-runner/utils.mlir b/mlir/test/mlir-cpu-runner/utils.mlir --- a/mlir/test/mlir-cpu-runner/utils.mlir +++ b/mlir/test/mlir-cpu-runner/utils.mlir @@ -12,8 +12,7 @@ dealloc %A : memref return } -// PRINT-0D: Unranked Memref rank = 0 descriptor@ = {{.*}} -// PRINT-0D: Memref base@ = {{.*}} rank = 0 offset = 0 data = +// PRINT-0D: Unranked Memref base@ = {{.*}} rank = 0 offset = 0 sizes = [] strides = [] data = // PRINT-0D: [2] func @print_1d() { @@ -26,7 +25,7 @@ dealloc %A : memref<16xf32> return } -// PRINT-1D: Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data = +// PRINT-1D: Unranked Memref base@ = {{.*}} rank = 1 offset = 0 sizes = [16] strides = [1] data = // PRINT-1D-NEXT: [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] func @print_3d() { @@ -43,7 +42,7 @@ dealloc %A : memref<3x4x5xf32> return } -// PRINT-3D: Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [3, 4, 5] strides = [20, 5, 1] data = +// PRINT-3D: Unranked Memref base@ = {{.*}} rank = 3 offset = 0 sizes = [3, 4, 5] strides = [20, 5, 1] data = // PRINT-3D-COUNT-4: {{.*[[:space:]].*}}2, 2, 2, 2, 2 // PRINT-3D-COUNT-4: {{.*[[:space:]].*}}2, 2, 2, 2, 2 // PRINT-3D-COUNT-2: {{.*[[:space:]].*}}2, 2, 2, 2, 2 diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp --- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp +++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp @@ -83,10 +83,10 @@ // Allows to register a MemRef with the CUDA runtime. Initializes array with // value. Helpful until we have transfer functions implemented. template -void mcuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef sizes, - llvm::ArrayRef strides, T value) { - assert(sizes.size() == strides.size()); - llvm::SmallVector denseStrides(strides.size()); +void mcuMemHostRegisterMemRef(const DynamicMemRefType &mem_ref, T value) { + llvm::SmallVector denseStrides(mem_ref.rank); + llvm::ArrayRef sizes(mem_ref.sizes, mem_ref.rank); + llvm::ArrayRef strides(mem_ref.strides, mem_ref.rank); std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(), std::multiplies()); @@ -98,20 +98,17 @@ denseStrides.back() = 1; assert(strides == llvm::makeArrayRef(denseStrides)); + auto *pointer = mem_ref.data + mem_ref.offset; std::fill_n(pointer, count, value); mgpuMemHostRegister(pointer, count * sizeof(T)); } extern "C" void mcuMemHostRegisterFloat(int64_t rank, void *ptr) { - auto *desc = static_cast *>(ptr); - auto sizes = llvm::ArrayRef(desc->sizes, rank); - auto strides = llvm::ArrayRef(desc->sizes + rank, rank); - mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 1.23f); + UnrankedMemRefType mem_ref = {rank, ptr}; + mcuMemHostRegisterMemRef(DynamicMemRefType(mem_ref), 1.23f); } extern "C" void mcuMemHostRegisterInt32(int64_t rank, void *ptr) { - auto *desc = static_cast *>(ptr); - auto sizes = llvm::ArrayRef(desc->sizes, rank); - auto strides = llvm::ArrayRef(desc->sizes + rank, rank); - mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 123); + UnrankedMemRefType mem_ref = {rank, ptr}; + mcuMemHostRegisterMemRef(DynamicMemRefType(mem_ref), 123); }