diff --git a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir
@@ -2,9 +2,7 @@
 
 func @main() {
   %data = alloc() : memref<2x6xi32>
-  %sum_and = alloc() : memref<2xi32>
-  %sum_or = alloc() : memref<2xi32>
-  %sum_min = alloc() : memref<2xi32>
+  %sum = alloc() : memref<2xi32>
   %cst0 = constant 0 : i32
   %cst1 = constant 1 : i32
   %cst2 = constant 2 : i32
@@ -25,7 +23,12 @@
   %c4 = constant 4 : index
   %c5 = constant 5 : index
   %c6 = constant 6 : index
-  
+
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
   store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -44,17 +47,19 @@
   gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %c2, %grid_y = %c1, %grid_z = %c1)
              threads(%tx, %ty, %tz) in (%block_x = %c6, %block_y = %c1, %block_z = %c1) {
     %val = load %data[%bx, %tx] : memref<2x6xi32>
-    %reduced_and = "gpu.all_reduce"(%val) ({}) { op = "and" } : (i32) -> (i32)
-    store %reduced_and, %sum_and[%bx] : memref<2xi32>
+    %reduced = "gpu.all_reduce"(%val) ({}) { op = "and" } : (i32) -> (i32)
+    store %reduced, %sum[%bx] : memref<2xi32>
     gpu.terminator
   }
 
-  %ptr_and = memref_cast %sum_and : memref<2xi32> to memref<*xi32>
-  call @print_memref_i32(%ptr_and) : (memref<*xi32>) -> ()
+  %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32>
+  call @print_memref_i32(%ptr) : (memref<*xi32>) -> ()
   // CHECK: [0, 2]
 
   return
 }
 
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
 func @print_memref_i32(memref<*xi32>)
 
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir
@@ -23,7 +23,12 @@
   %c4 = constant 4 : index
   %c5 = constant 5 : index
   %c6 = constant 6 : index
-  
+
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
   store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -54,5 +59,7 @@
   return
 }
 
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
 func @print_memref_i32(memref<*xi32>)
 
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir
@@ -23,7 +23,12 @@
   %c4 = constant 4 : index
   %c5 = constant 5 : index
   %c6 = constant 6 : index
-  
+
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
   store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -54,5 +59,7 @@
   return
 }
 
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
 func @print_memref_i32(memref<*xi32>)
 
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir
@@ -23,7 +23,12 @@
   %c4 = constant 4 : index
   %c5 = constant 5 : index
   %c6 = constant 6 : index
-  
+
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
   store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -54,5 +59,7 @@
   return
 }
 
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
 func @print_memref_i32(memref<*xi32>)
 
diff --git a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
--- a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
+++ b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir
@@ -23,7 +23,12 @@
   %c4 = constant 4 : index
   %c5 = constant 5 : index
   %c6 = constant 6 : index
-  
+
+  %cast_data = memref_cast %data : memref<2x6xi32> to memref<?x?xi32>
+  call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref<?x?xi32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xi32> to memref<?xi32>
+  call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref<?xi32>) -> ()
+
   store %cst0, %data[%c0, %c0] : memref<2x6xi32>
   store %cst1, %data[%c0, %c1] : memref<2x6xi32>
   store %cst2, %data[%c0, %c2] : memref<2x6xi32>
@@ -54,5 +59,7 @@
   return
 }
 
+func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
+func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref<?x?xi32>)
 func @print_memref_i32(memref<*xi32>)
 
diff --git a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
--- a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
+++ b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir
@@ -25,6 +25,13 @@
   %c5 = constant 5 : index
   %c6 = constant 6 : index
 
+  %cast_data = memref_cast %data : memref<2x6xf32> to memref<?x?xf32>
+  call @mcuMemHostRegisterMemRef2dFloat(%cast_data) : (memref<?x?xf32>) -> ()
+  %cast_sum = memref_cast %sum : memref<2xf32> to memref<?xf32>
+  call @mcuMemHostRegisterMemRef1dFloat(%cast_sum) : (memref<?xf32>) -> ()
+  %cast_mul = memref_cast %mul : memref<2xf32> to memref<?xf32>
+  call @mcuMemHostRegisterMemRef1dFloat(%cast_mul) : (memref<?xf32>) -> ()
+
   store %cst0, %data[%c0, %c0] : memref<2x6xf32>
   store %cst1, %data[%c0, %c1] : memref<2x6xf32>
   store %cst2, %data[%c0, %c2] : memref<2x6xf32>
@@ -61,4 +68,6 @@
   return
 }
 
+func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref<?xf32>)
+func @mcuMemHostRegisterMemRef2dFloat(%ptr : memref<?x?xf32>)
 func @print_memref_f32(memref<*xf32>)
diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
--- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
+++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp
@@ -15,6 +15,7 @@
 #include <cassert>
 #include <numeric>
 
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/Support/raw_ostream.h"
 
 #include "cuda.h"
@@ -89,24 +90,42 @@
 
-// Allows to register a MemRef with the CUDA runtime. Initializes array with
-// value. Helpful until we have transfer functions implemented.
-template <typename T, int N>
-void mcuMemHostRegisterMemRef(const MemRefType<T, N> *arg, T value) {
-  auto count = std::accumulate(arg->sizes, arg->sizes + N, 1,
-                               std::multiplies<int64_t>());
-  std::fill_n(arg->data, count, value);
-  mcuMemHostRegister(arg->data, count * sizeof(T));
+// Registers the memory of a MemRef with the CUDA runtime and fills all of its
+// elements with `value`. Helpful until we have transfer functions implemented.
+// `pointer` is the first element (aligned pointer plus offset). Only densely
+// packed row-major layouts are supported, which is checked by assertion only.
+template <typename T>
+void mcuMemHostRegisterMemRef(T *pointer, llvm::ArrayRef<int64_t> sizes,
+                              llvm::ArrayRef<int64_t> strides, T value) {
+  assert(sizes.size() == strides.size());
+
+  // Walk dimensions innermost-first, recording the stride of a densely packed
+  // layout and accumulating the element count; rank 0 yields a count of 1.
+  llvm::SmallVector<int64_t, 4> denseStrides(sizes.size());
+  int64_t count = 1;
+  for (size_t dim = sizes.size(); dim-- > 0;) {
+    denseStrides[dim] = count;
+    count *= sizes[dim];
+  }
+
+  // Only densely packed MemRefs are currently supported.
+  assert(strides == llvm::makeArrayRef(denseStrides));
+
+  std::fill_n(pointer, count, value);
+  mcuMemHostRegister(pointer, count * sizeof(T));
 }
 
 extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated,
                                                 float *aligned, int64_t offset,
                                                 int64_t size, int64_t stride) {
-  MemRefType<float, 1> descriptor;
-  descriptor.basePtr = allocated;
-  descriptor.data = aligned;
-  descriptor.offset = offset;
-  descriptor.sizes[0] = size;
-  descriptor.strides[0] = stride;
-  mcuMemHostRegisterMemRef(&descriptor, 1.23f);
+  mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 1.23f);
+}
+
+extern "C" void mcuMemHostRegisterMemRef2dFloat(float *allocated,
+                                                float *aligned, int64_t offset,
+                                                int64_t size0, int64_t size1,
+                                                int64_t stride0,
+                                                int64_t stride1) {
+  mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
+                           1.23f);
 }
 
 extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated,
@@ -115,15 +134,31 @@
                                                 int64_t size2, int64_t stride0,
                                                 int64_t stride1,
                                                 int64_t stride2) {
-  MemRefType<float, 3> descriptor;
-  descriptor.basePtr = allocated;
-  descriptor.data = aligned;
-  descriptor.offset = offset;
-  descriptor.sizes[0] = size0;
-  descriptor.strides[0] = stride0;
-  descriptor.sizes[1] = size1;
-  descriptor.strides[1] = stride1;
-  descriptor.sizes[2] = size2;
-  descriptor.strides[2] = stride2;
-  mcuMemHostRegisterMemRef(&descriptor, 1.23f);
+  mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
+                           {stride0, stride1, stride2}, 1.23f);
+}
+
+extern "C" void mcuMemHostRegisterMemRef1dInt32(int32_t *allocated,
+                                                int32_t *aligned,
+                                                int64_t offset, int64_t size,
+                                                int64_t stride) {
+  mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 123);
+}
+
+extern "C" void mcuMemHostRegisterMemRef2dInt32(int32_t *allocated,
+                                                int32_t *aligned,
+                                                int64_t offset, int64_t size0,
+                                                int64_t size1, int64_t stride0,
+                                                int64_t stride1) {
+  mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1},
+                           123);
+}
+
+extern "C" void
+mcuMemHostRegisterMemRef3dInt32(int32_t *allocated, int32_t *aligned,
+                                int64_t offset, int64_t size0, int64_t size1,
+                                int64_t size2, int64_t stride0, int64_t stride1,
+                                int64_t stride2) {
+  mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2},
+                           {stride0, stride1, stride2}, 123);
 }