diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-2-4-lib.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-2-4-lib.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-matmul-2-4-lib.mlir
@@ -0,0 +1,236 @@
+//
+// NOTE: this test requires gpu-sm80
+//
+// RUN: mlir-opt \
+// RUN:   --pass-pipeline="builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm,affine-expand-index-ops,lower-affine,convert-arith-to-llvm),convert-vector-to-llvm,canonicalize,cse,gpu.module(gpu-to-cubin{chip=sm_80 features=+ptx71}))" \
+// RUN:   %s \
+// RUN: | mlir-opt --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
+// RUN:   --convert-arith-to-llvm --gpu-to-llvm --reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_c_runner_utils \
+// RUN:   --e main --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module attributes {gpu.container_module} {
+
+  // Host-side driver for the GPU sparse library ops (no device kernels are
+  // defined in this module). Copies A, B and C to the device, computes C
+  // from the 2:4 sparse A and the transposed dense B with gpu.spmm, and
+  // copies C back. Every step is ordered on a single async token chain.
+  // Note: despite its name, this is a plain (not sampled) matmul.
+  func.func @sampled_matmul(%a : memref<16x32xf16>,
+                            %b : memref<8x32xf16>,
+                            %c : memref<16x8xf16>) {
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    // Element count of each workspace buffer. A fixed upper bound is used
+    // until the sizes from gpu.spmm_buffer_size can be unpacked (see TODO).
+    %c1048576 = arith.constant 1048576 : index
+
+    // Copy the operands into device buffers.
+    %token0 = gpu.wait async
+    %d_a, %token1 = gpu.alloc async [%token0] () : memref<16x32xf16>
+    %d_b, %token2 = gpu.alloc async [%token1] () : memref<8x32xf16>
+    %d_c, %token3 = gpu.alloc async [%token2] () : memref<16x8xf16>
+    %token4 = gpu.memcpy async [%token3] %d_a, %a : memref<16x32xf16>, memref<16x32xf16>
+    %token5 = gpu.memcpy async [%token4] %d_b, %b : memref<8x32xf16>, memref<8x32xf16>
+    %token6 = gpu.memcpy async [%token5] %d_c, %c : memref<16x8xf16>, memref<16x8xf16>
+
+    // Workspace buffers for gpu.spmm.
+    %mem1, %token7 = gpu.alloc async [%token6] (%c1048576) : memref<?xf16>
+    %mem2, %token8 = gpu.alloc async [%token7] (%c1048576) : memref<?xf16>
+    %mem3, %token9 = gpu.alloc async [%token8] (%c1048576) : memref<?xf16>
+
+    // Library handles: A as a 2:4 sparse matrix, B and C as dense matrices.
+    %env, %token10 = gpu.create_sparse_env async [%token9]
+    %spmat, %token11 = gpu.create_2to4_spmat async [%token10] %env, %c16, %c32, %d_a : memref<16x32xf16>
+    %dnmat, %token12 = gpu.create_dn_mat async [%token11] %env, %c8, %c32, %d_b : memref<8x32xf16>
+    %dnmat2, %token13 = gpu.create_dn_mat async [%token12] %env, %c16, %c8, %d_c : memref<16x8xf16>
+
+    // The whole sequence stays on one async token chain; the only
+    // host-side gpu.wait is the final one below.
+    %bufferSzs, %token14 = gpu.spmm_buffer_size async [%token13] %env, %spmat{NON_TRANSPOSE}, %dnmat{TRANSPOSE}, %dnmat2 : tuple<index,index,index> into f16
+    // TODO: implement the op to unpack tuple %bufferSzs.
+    %token15 = gpu.spmm async [%token14] %env, %spmat{NON_TRANSPOSE}, %dnmat{TRANSPOSE}, %dnmat2, %mem1, %mem2, %mem3 : memref<?xf16>,memref<?xf16>,memref<?xf16> into f16
+
+    // Release every handle (including %dnmat2), copy C back to the host,
+    // and free all device buffers.
+    %token16 = gpu.destroy_sp_mat async [%token15] %spmat
+    %token17 = gpu.destroy_dn_mat async [%token16] %dnmat
+    %token18 = gpu.destroy_dn_mat async [%token17] %dnmat2
+    %token19 = gpu.destroy_sparse_env async [%token18] %env
+    %token20 = gpu.memcpy async [%token19] %c, %d_c : memref<16x8xf16>, memref<16x8xf16>
+    %token21 = gpu.dealloc async [%token20] %d_c : memref<16x8xf16>
+    %token22 = gpu.dealloc async [%token21] %d_b : memref<8x32xf16>
+    %token23 = gpu.dealloc async [%token22] %d_a : memref<16x32xf16>
+    %token24 = gpu.dealloc async [%token23] %mem3 : memref<?xf16>
+    %token25 = gpu.dealloc async [%token24] %mem2 : memref<?xf16>
+    %token26 = gpu.dealloc async [%token25] %mem1 : memref<?xf16>
+    gpu.wait [%token26]
+    return
+  }
+
+  // Code that runs on the host.
+
+  //
+  // This test performs a matrix multiplication
+  //    C = A x B
+  // using NVidia 2:4 structured sparsity for A.
+  //
+  func.func @main() {
+    %f0 = arith.constant 0.0 : f16
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c2 = arith.constant 2 : index
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+
+    // Matrices A, B, C (16x32, 32x8, 16x8).
+    %a = memref.alloc() : memref<16x32xf16> // 16x32 but 2:4, row-major
+    %b = memref.alloc() : memref<8x32xf16>  // regular dense column-major
+    %c = memref.alloc() : memref<16x8xf16>  // accumulator row-major
+
+    //
+    // Setup matrix A: A[i][2j] = i + j + 1 and A[i][2j+1] = 0 for
+    // 0 <= j < 16, so each group of 4 columns has at most 2 nonzeros.
+    //
+    scf.for %ai = %c0 to %c16 step %c1 {
+      scf.for %aj = %c0 to %c16 step %c1 {
+        %a0 = arith.addi %ai, %aj : index
+        %a1 = arith.addi %a0, %c1 : index
+        %a2 = arith.index_cast %a1 : index to i32
+        %a3 = arith.sitofp %a2 : i32 to f16
+        %ajj = arith.muli %aj, %c2 : index
+        %ajj2 = arith.addi %ajj, %c1 : index
+        memref.store %a3, %a[%ai, %ajj] : memref<16x32xf16>
+        memref.store %f0, %a[%ai, %ajj2] : memref<16x32xf16>
+      }
+    }
+
+    //
+    // Setup matrix B (stored transposed as 8x32): %b[i][j] = i - j.
+    //
+    scf.for %bi = %c0 to %c8 step %c1 {
+      scf.for %bj = %c0 to %c32 step %c1 {
+        %b0 = arith.subi %bi, %bj : index
+        %b1 = arith.index_cast %b0 : index to i32
+        %b2 = arith.sitofp %b1 : i32 to f16
+        memref.store %b2, %b[%bi, %bj] : memref<8x32xf16>
+      }
+    }
+
+    //
+    // Reset matrix C.
+    //
+    scf.for %ci = %c0 to %c16 step %c1 {
+      scf.for %cj = %c0 to %c8 step %c1 {
+        memref.store %f0, %c[%ci, %cj] : memref<16x8xf16>
+      }
+    }
+
+    //
+    // Sanity check on input matrix A. It is printed in its full,
+    // uncompressed form, so every row shows all 32 entries.
+    //
+    // Note that it really is a 16x32 matrix:
+    //  | 1 0 2 0 3 0 ...
+    //  | 2 0 3 0 4 0 ...
+    // etc.
+    //
+    // CHECK:      ( 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0 )
+    // CHECK-NEXT: ( 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0, 17, 0 )
+    // CHECK-NEXT: ( 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0 )
+    // CHECK-NEXT: ( 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0, 19, 0 )
+    // CHECK-NEXT: ( 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0, 19, 0, 20, 0 )
+    // CHECK-NEXT: ( 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0 )
+    // CHECK-NEXT: ( 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0 )
+    // CHECK-NEXT: ( 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0 )
+    // CHECK-NEXT: ( 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0 )
+    // CHECK-NEXT: ( 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0 )
+    // CHECK-NEXT: ( 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0 )
+    // CHECK-NEXT: ( 12, 0, 13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0, 27, 0 )
+    // CHECK-NEXT: ( 13, 0, 14, 0, 15, 0, 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0, 27, 0, 28, 0 )
+    // CHECK-NEXT: ( 14, 0, 15, 0, 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0, 27, 0, 28, 0, 29, 0 )
+    // CHECK-NEXT: ( 15, 0, 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0, 27, 0, 28, 0, 29, 0, 30, 0 )
+    // CHECK-NEXT: ( 16, 0, 17, 0, 18, 0, 19, 0, 20, 0, 21, 0, 22, 0, 23, 0, 24, 0, 25, 0, 26, 0, 27, 0, 28, 0, 29, 0, 30, 0, 31, 0 )
+    //
+    scf.for %pai = %c0 to %c16 step %c1 {
+      %pa0 = vector.transfer_read %a[%pai, %c0], %f0 : memref<16x32xf16>, vector<32xf16>
+      vector.print %pa0 : vector<32xf16>
+    }
+
+    //
+    // Sanity check on input matrix 32x8 B.
+    // Note that this is really shown as B^T
+    //
+    // CHECK-NEXT: ( 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31 )
+    // CHECK-NEXT: ( 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30 )
+    // CHECK-NEXT: ( 2, 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29 )
+    // CHECK-NEXT: ( 3, 2, 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28 )
+    // CHECK-NEXT: ( 4, 3, 2, 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27 )
+    // CHECK-NEXT: ( 5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26 )
+    // CHECK-NEXT: ( 6, 5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25 )
+    // CHECK-NEXT: ( 7, 6, 5, 4, 3, 2, 1, 0, -1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24 )
+    //
+    //
+    scf.for %pbi = %c0 to %c8 step %c1 {
+      %pb0 = vector.transfer_read %b[%pbi, %c0], %f0 : memref<8x32xf16>, vector<32xf16>
+      vector.print %pb0 : vector<32xf16>
+    }
+
+    // Maps the provided host buffers into the device address space.
+    // Writes from the host are guaranteed to be visible to device
+    // kernels that are launched afterwards. Writes from the device
+    // are guaranteed to be visible on the host after synchronizing
+    // with the device kernel completion.
+    %cast_a = memref.cast %a : memref<16x32xf16> to memref<*xf16>
+    gpu.host_register %cast_a : memref<*xf16>
+    %cast_b = memref.cast %b : memref<8x32xf16> to memref<*xf16>
+    gpu.host_register %cast_b : memref<*xf16>
+    %cast_c = memref.cast %c : memref<16x8xf16> to memref<*xf16>
+    gpu.host_register %cast_c : memref<*xf16>
+
+    // Run the sparse matmul; this is a host call, not a kernel launch.
+    call @sampled_matmul (%a, %b, %c): (memref<16x32xf16>, memref<8x32xf16>, memref<16x8xf16>) -> ()
+
+    // Unmaps the host buffers.
+    gpu.host_unregister %cast_a : memref<*xf16>
+    gpu.host_unregister %cast_b : memref<*xf16>
+    gpu.host_unregister %cast_c : memref<*xf16>
+
+    //
+    // Verify computed matrix C.
+    //
+    // CHECK-NEXT: ( -2720, -2584, -2448, -2312, -2176, -2040, -1904, -1768 )
+    // CHECK-NEXT: ( -2960, -2808, -2656, -2504, -2352, -2200, -2048, -1896 )
+    // CHECK-NEXT: ( -3200, -3032, -2864, -2696, -2528, -2360, -2192, -2024 )
+    // CHECK-NEXT: ( -3440, -3256, -3072, -2888, -2704, -2520, -2336, -2152 )
+    // CHECK-NEXT: ( -3680, -3480, -3280, -3080, -2880, -2680, -2480, -2280 )
+    // CHECK-NEXT: ( -3920, -3704, -3488, -3272, -3056, -2840, -2624, -2408 )
+    // CHECK-NEXT: ( -4160, -3928, -3696, -3464, -3232, -3000, -2768, -2536 )
+    // CHECK-NEXT: ( -4400, -4152, -3904, -3656, -3408, -3160, -2912, -2664 )
+    // CHECK-NEXT: ( -4640, -4376, -4112, -3848, -3584, -3320, -3056, -2792 )
+    // CHECK-NEXT: ( -4880, -4600, -4320, -4040, -3760, -3480, -3200, -2920 )
+    // CHECK-NEXT: ( -5120, -4824, -4528, -4232, -3936, -3640, -3344, -3048 )
+    // CHECK-NEXT: ( -5360, -5048, -4736, -4424, -4112, -3800, -3488, -3176 )
+    // CHECK-NEXT: ( -5600, -5272, -4944, -4616, -4288, -3960, -3632, -3304 )
+    // CHECK-NEXT: ( -5840, -5496, -5152, -4808, -4464, -4120, -3776, -3432 )
+    // CHECK-NEXT: ( -6080, -5720, -5360, -5000, -4640, -4280, -3920, -3560 )
+    // CHECK-NEXT: ( -6320, -5944, -5568, -5192, -4816, -4440, -4064, -3688 )
+    //
+    scf.for %pci = %c0 to %c16 step %c1 {
+      %pc0 = vector.transfer_read %c[%pci, %c0], %f0 : memref<16x8xf16>, vector<8xf16>
+      vector.print %pc0 : vector<8xf16>
+    }
+
+    // Release the host buffers.
+    memref.dealloc %a : memref<16x32xf16>
+    memref.dealloc %b : memref<8x32xf16>
+    memref.dealloc %c : memref<16x8xf16>
+
+    return
+  }
+}