diff --git a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir --- a/mlir/test/mlir-cuda-runner/all-reduce-and.mlir +++ b/mlir/test/mlir-cuda-runner/all-reduce-and.mlir @@ -24,10 +24,10 @@ %c5 = constant 5 : index %c6 = constant 6 : index - %cast_data = memref_cast %data : memref<2x6xi32> to memref - call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref) -> () - %cast_sum = memref_cast %sum : memref<2xi32> to memref - call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref) -> () + %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32> + call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> () + %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32> + call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> () store %cst0, %data[%c0, %c0] : memref<2x6xi32> store %cst1, %data[%c0, %c1] : memref<2x6xi32> @@ -52,14 +52,12 @@ gpu.terminator } - %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32> - call @print_memref_i32(%ptr) : (memref<*xi32>) -> () + call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> () // CHECK: [0, 2] return } -func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref) -func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref) +func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>) func @print_memref_i32(memref<*xi32>) diff --git a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir --- a/mlir/test/mlir-cuda-runner/all-reduce-max.mlir +++ b/mlir/test/mlir-cuda-runner/all-reduce-max.mlir @@ -24,10 +24,10 @@ %c5 = constant 5 : index %c6 = constant 6 : index - %cast_data = memref_cast %data : memref<2x6xi32> to memref - call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref) -> () - %cast_sum = memref_cast %sum : memref<2xi32> to memref - call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref) -> () + %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32> + call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> () + %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32> + call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> () store %cst0, %data[%c0, %c0] : memref<2x6xi32> store %cst1, %data[%c0, %c1] : memref<2x6xi32> @@ -52,14 +52,12 @@ gpu.terminator } - %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32> - call @print_memref_i32(%ptr) : (memref<*xi32>) -> () + call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> () // CHECK: [16, 11] return } -func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref) -func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref) +func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>) func @print_memref_i32(memref<*xi32>) diff --git a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir --- a/mlir/test/mlir-cuda-runner/all-reduce-min.mlir +++ b/mlir/test/mlir-cuda-runner/all-reduce-min.mlir @@ -24,10 +24,10 @@ %c5 = constant 5 : index %c6 = constant 6 : index - %cast_data = memref_cast %data : memref<2x6xi32> to memref - call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref) -> () - %cast_sum = memref_cast %sum : memref<2xi32> to memref - call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref) -> () + %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32> + call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> () + %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32> + call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> () store %cst0, %data[%c0, %c0] : memref<2x6xi32> store %cst1, %data[%c0, %c1] : memref<2x6xi32> @@ -52,14 +52,12 @@ gpu.terminator } - %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32> - call @print_memref_i32(%ptr) : (memref<*xi32>) -> () + call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> () // CHECK: [0, 2] return } -func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref) -func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref) +func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>) func @print_memref_i32(memref<*xi32>) diff --git a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir --- a/mlir/test/mlir-cuda-runner/all-reduce-op.mlir +++ b/mlir/test/mlir-cuda-runner/all-reduce-op.mlir @@ -8,7 +8,8 @@ %sx = dim %dst, 2 : memref %sy = dim %dst, 1 : memref %sz = dim %dst, 0 : memref - call @mcuMemHostRegisterMemRef3dFloat(%dst) : (memref) -> () + %cast_dst = memref_cast %dst : memref to memref<*xf32> + call @mcuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> () gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one) threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %sy, %block_z = %sz) { %t0 = muli %tz, %block_y : index @@ -21,10 +22,9 @@ store %sum, %dst[%tz, %ty, %tx] : memref gpu.terminator } - %U = memref_cast %dst : memref to memref<*xf32> - call @print_memref_f32(%U) : (memref<*xf32>) -> () + call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> () return } -func @mcuMemHostRegisterMemRef3dFloat(%ptr : memref) +func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>) func @print_memref_f32(%ptr : memref<*xf32>) diff --git a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir --- a/mlir/test/mlir-cuda-runner/all-reduce-or.mlir +++ b/mlir/test/mlir-cuda-runner/all-reduce-or.mlir @@ -24,10 +24,10 @@ %c5 = constant 5 : index %c6 = constant 6 : index - %cast_data = memref_cast %data : memref<2x6xi32> to memref - call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref) -> () - %cast_sum = memref_cast %sum : memref<2xi32> to memref - call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref) -> () + %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32> + call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> () + %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32> + call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> () store %cst0, %data[%c0, %c0] : memref<2x6xi32> store %cst1, %data[%c0, %c1] : memref<2x6xi32> @@ -52,14 +52,12 @@ gpu.terminator } - %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32> - call @print_memref_i32(%ptr) : (memref<*xi32>) -> () + call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> () // CHECK: [31, 15] return } -func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref) -func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref) +func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>) func @print_memref_i32(memref<*xi32>) diff --git a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir --- a/mlir/test/mlir-cuda-runner/all-reduce-region.mlir +++ b/mlir/test/mlir-cuda-runner/all-reduce-region.mlir @@ -6,7 +6,8 @@ %dst = memref_cast %arg : memref<35xf32> to memref %one = constant 1 : index %sx = dim %dst, 0 : memref - call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref) -> () + %cast_dst = memref_cast %dst : memref to memref<*xf32> + call @mcuMemHostRegisterFloat(%cast_dst) : (memref<*xf32>) -> () gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one) threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) { %val = index_cast %tx : index to i32 @@ -19,10 +20,9 @@ store %res, %dst[%tx] : memref gpu.terminator } - %U = memref_cast %dst : memref to memref<*xf32> - call @print_memref_f32(%U) : (memref<*xf32>) -> () + call @print_memref_f32(%cast_dst) : (memref<*xf32>) -> () return } -func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref) +func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>) func @print_memref_f32(memref<*xf32>) diff --git a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir --- a/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir +++ b/mlir/test/mlir-cuda-runner/all-reduce-xor.mlir @@ -24,10 +24,10 @@ %c5 = constant 5 : index %c6 = constant 6 : index - %cast_data = memref_cast %data : memref<2x6xi32> to memref - call @mcuMemHostRegisterMemRef2dInt32(%cast_data) : (memref) -> () - %cast_sum = memref_cast %sum : memref<2xi32> to memref - call @mcuMemHostRegisterMemRef1dInt32(%cast_sum) : (memref) -> () + %cast_data = memref_cast %data : memref<2x6xi32> to memref<*xi32> + call @mcuMemHostRegisterInt32(%cast_data) : (memref<*xi32>) -> () + %cast_sum = memref_cast %sum : memref<2xi32> to memref<*xi32> + call @mcuMemHostRegisterInt32(%cast_sum) : (memref<*xi32>) -> () store %cst0, %data[%c0, %c0] : memref<2x6xi32> store %cst1, %data[%c0, %c1] : memref<2x6xi32> @@ -52,14 +52,12 @@ gpu.terminator } - %ptr = memref_cast %sum : memref<2xi32> to memref<*xi32> - call @print_memref_i32(%ptr) : (memref<*xi32>) -> () + call @print_memref_i32(%cast_sum) : (memref<*xi32>) -> () // CHECK: [31, 1] return } -func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref) -func @mcuMemHostRegisterMemRef2dInt32(%ptr : memref) +func @mcuMemHostRegisterInt32(%ptr : memref<*xi32>) func @print_memref_i32(memref<*xi32>) diff --git a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir --- a/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir +++ b/mlir/test/mlir-cuda-runner/gpu-to-cubin.mlir @@ -16,8 +16,8 @@ %arg0 = alloc() : memref<5xf32> %21 = constant 5 : i32 %22 = memref_cast %arg0 : memref<5xf32> to memref - call @mcuMemHostRegisterMemRef1dFloat(%22) : (memref) -> () %23 = memref_cast %22 : memref to memref<*xf32> + call @mcuMemHostRegisterFloat(%23) : (memref<*xf32>) -> () call @print_memref_f32(%23) : (memref<*xf32>) -> () %24 = constant 1.0 : f32 call @other_func(%24, %22) : (f32, memref) -> () @@ -25,5 +25,5 @@ return } -func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref) +func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>) func @print_memref_f32(%ptr : memref<*xf32>) diff --git a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir --- a/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir +++ b/mlir/test/mlir-cuda-runner/multiple-all-reduce.mlir @@ -25,12 +25,12 @@ %c5 = constant 5 : index %c6 = constant 6 : index - %cast_data = memref_cast %data : memref<2x6xf32> to memref - call @mcuMemHostRegisterMemRef2dFloat(%cast_data) : (memref) -> () - %cast_sum = memref_cast %sum : memref<2xf32> to memref - call @mcuMemHostRegisterMemRef1dFloat(%cast_sum) : (memref) -> () - %cast_mul = memref_cast %mul : memref<2xf32> to memref - call @mcuMemHostRegisterMemRef1dFloat(%cast_mul) : (memref) -> () + %cast_data = memref_cast %data : memref<2x6xf32> to memref<*xf32> + call @mcuMemHostRegisterFloat(%cast_data) : (memref<*xf32>) -> () + %cast_sum = memref_cast %sum : memref<2xf32> to memref<*xf32> + call @mcuMemHostRegisterFloat(%cast_sum) : (memref<*xf32>) -> () + %cast_mul = memref_cast %mul : memref<2xf32> to memref<*xf32> + call @mcuMemHostRegisterFloat(%cast_mul) : (memref<*xf32>) -> () store %cst0, %data[%c0, %c0] : memref<2x6xf32> store %cst1, %data[%c0, %c1] : memref<2x6xf32> @@ -57,17 +57,14 @@ gpu.terminator } - %ptr_sum = memref_cast %sum : memref<2xf32> to memref<*xf32> - call @print_memref_f32(%ptr_sum) : (memref<*xf32>) -> () + call @print_memref_f32(%cast_sum) : (memref<*xf32>) -> () // CHECK: [31, 39] - %ptr_mul = memref_cast %mul : memref<2xf32> to memref<*xf32> - call @print_memref_f32(%ptr_mul) : (memref<*xf32>) -> () + call @print_memref_f32(%cast_mul) : (memref<*xf32>) -> () // CHECK: [0, 27720] return } -func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref) -func @mcuMemHostRegisterMemRef2dFloat(%ptr : memref) +func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>) func @print_memref_f32(memref<*xf32>) diff --git a/mlir/test/mlir-cuda-runner/shuffle.mlir b/mlir/test/mlir-cuda-runner/shuffle.mlir --- a/mlir/test/mlir-cuda-runner/shuffle.mlir +++ b/mlir/test/mlir-cuda-runner/shuffle.mlir @@ -6,7 +6,8 @@ %dst = memref_cast %arg : memref<13xf32> to memref %one = constant 1 : index %sx = dim %dst, 0 : memref - call @mcuMemHostRegisterMemRef1dFloat(%dst) : (memref) -> () + %cast_dest = memref_cast %dst : memref to memref<*xf32> + call @mcuMemHostRegisterFloat(%cast_dest) : (memref<*xf32>) -> () gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one) threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) { %t0 = index_cast %tx : index to i32 @@ -22,10 +23,9 @@ store %value, %dst[%tx] : memref gpu.terminator } - %U = memref_cast %dst : memref to memref<*xf32> - call @print_memref_f32(%U) : (memref<*xf32>) -> () + call @print_memref_f32(%cast_dest) : (memref<*xf32>) -> () return } -func @mcuMemHostRegisterMemRef1dFloat(%ptr : memref) +func @mcuMemHostRegisterFloat(%ptr : memref<*xf32>) func @print_memref_f32(%ptr : memref<*xf32>) diff --git a/mlir/test/mlir-cuda-runner/two-modules.mlir b/mlir/test/mlir-cuda-runner/two-modules.mlir --- a/mlir/test/mlir-cuda-runner/two-modules.mlir +++ b/mlir/test/mlir-cuda-runner/two-modules.mlir @@ -6,7 +6,8 @@ %dst = memref_cast %arg : memref<13xi32> to memref %one = constant 1 : index %sx = dim %dst, 0 : memref - call @mcuMemHostRegisterMemRef1dInt32(%dst) : (memref) -> () + %cast_dst = memref_cast %dst : memref to memref<*xi32> + call @mcuMemHostRegisterInt32(%cast_dst) : (memref<*xi32>) -> () gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one) threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) { %t0 = index_cast %tx : index to i32 @@ -19,10 +20,9 @@ store %t0, %dst[%tx] : memref gpu.terminator } - %U = memref_cast %dst : memref to memref<*xi32> - call @print_memref_i32(%U) : (memref<*xi32>) -> () + call @print_memref_i32(%cast_dst) : (memref<*xi32>) -> () return } -func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref) -func @print_memref_i32(%ptr : memref<*xi32>) +func @mcuMemHostRegisterInt32(%memref : memref<*xi32>) +func @print_memref_i32(%memref : memref<*xi32>) diff --git a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp --- a/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp +++ b/mlir/tools/mlir-cuda-runner/cuda-runtime-wrappers.cpp @@ -15,6 +15,7 @@ #include #include +#include "mlir/ExecutionEngine/CRunnerUtils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/raw_ostream.h" @@ -79,15 +80,6 @@ "MemHostRegister"); } -// A struct that corresponds to how MLIR represents memrefs. -template struct MemRefType { - T *basePtr; - T *data; - int64_t offset; - int64_t sizes[N]; - int64_t strides[N]; -}; - // Allows to register a MemRef with the CUDA runtime. Initializes array with // value. Helpful until we have transfer functions implemented. template @@ -110,52 +102,16 @@ mcuMemHostRegister(pointer, count * sizeof(T)); } -extern "C" void mcuMemHostRegisterMemRef1dFloat(float *allocated, - float *aligned, int64_t offset, - int64_t size, int64_t stride) { - mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 1.23f); -} - -extern "C" void mcuMemHostRegisterMemRef2dFloat(float *allocated, - float *aligned, int64_t offset, - int64_t size0, int64_t size1, - int64_t stride0, - int64_t stride1) { - mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1}, - 1.23f); -} - -extern "C" void mcuMemHostRegisterMemRef3dFloat(float *allocated, - float *aligned, int64_t offset, - int64_t size0, int64_t size1, - int64_t size2, int64_t stride0, - int64_t stride1, - int64_t stride2) { - mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2}, - {stride0, stride1, stride2}, 1.23f); -} - -extern "C" void mcuMemHostRegisterMemRef1dInt32(int32_t *allocated, - int32_t *aligned, - int64_t offset, int64_t size, - int64_t stride) { - mcuMemHostRegisterMemRef(aligned + offset, {size}, {stride}, 123); -} - -extern "C" void mcuMemHostRegisterMemRef2dInt32(int32_t *allocated, - int32_t *aligned, - int64_t offset, int64_t size0, - int64_t size1, int64_t stride0, - int64_t stride1) { - mcuMemHostRegisterMemRef(aligned + offset, {size0, size1}, {stride0, stride1}, - 123); +extern "C" void mcuMemHostRegisterFloat(int64_t rank, void *ptr) { + auto *desc = static_cast *>(ptr); + auto sizes = llvm::ArrayRef(desc->sizes, rank); + auto strides = llvm::ArrayRef(desc->sizes + rank, rank); + mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 1.23f); } -extern "C" void -mcuMemHostRegisterMemRef3dInt32(int32_t *allocated, int32_t *aligned, - int64_t offset, int64_t size0, int64_t size1, - int64_t size2, int64_t stride0, int64_t stride1, - int64_t stride2) { - mcuMemHostRegisterMemRef(aligned + offset, {size0, size1, size2}, - {stride0, stride1, stride2}, 123); +extern "C" void mcuMemHostRegisterInt32(int64_t rank, void *ptr) { + auto *desc = static_cast *>(ptr); + auto sizes = llvm::ArrayRef(desc->sizes, rank); + auto strides = llvm::ArrayRef(desc->sizes + rank, rank); + mcuMemHostRegisterMemRef(desc->data + desc->offset, sizes, strides, 123); }