diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/lit.local.cfg
@@ -0,0 +1,2 @@
+if not config.enable_cuda_runner:
+  config.unsupported = True
diff --git a/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/GPU/CUDA/sparse-mma-2-4-f16.mlir
@@ -0,0 +1,310 @@
+// RUN: mlir-opt --pass-pipeline="builtin.module(gpu.module(strip-debuginfo,convert-gpu-to-nvvm,convert-nvgpu-to-nvvm,affine-expand-index-ops,lower-affine,convert-arith-to-llvm),convert-vector-to-llvm,canonicalize,cse,gpu.module(gpu-to-cubin{chip=sm_80 features=+ptx71}))" %s \
+// RUN: | mlir-opt --convert-scf-to-cf -convert-cf-to-llvm --convert-vector-to-llvm --convert-arith-to-llvm --gpu-to-llvm --reconcile-unrealized-casts \
+// RUN: | mlir-cpu-runner \
+// RUN:   --shared-libs=%mlir_cuda_runtime \
+// RUN:   --shared-libs=%mlir_runner_utils \
+// RUN:   --e main --entry-point-result=void \
+// RUN: | FileCheck %s
+
+module attributes {gpu.container_module} {
+
+  // Kernels that run on the device.
+
+  gpu.module @kernels {
+
+    //
+    // An NVidia GPU kernel to compute
+    //   C = A x B
+    // (or, technically, D = A x B + C)
+    // using 2:4 structured sparsity for A.
+    //
+    // This kernel provides a building block for sparse compilation of a larger
+    // enveloping matrix multiplication computation on a GPU.
+    //
+    // Operand A values (2:4 sparse): row-major format, logically "16x32xf16"
+    //                                but "16x16xf16" after compression.
+    //
+    // Operand A metadata:
+    //   - threads 2i     -> col 0
+    //   - threads 2i + 1 -> col 1
+    //
+    // Operand B (dense): column-major format.
+    //
+    // Operand C (accum): assumed zero on entry, used as output.
+    //
+    gpu.func @mma_sp_sync_f16_16832(
+        %argA: memref<16x16xf16>,
+        %argA_meta: memref<16x2xi16>,
+        %argB: memref<8x32xf16>,
+        %argC: memref<16x8xf16>) kernel {
+      %f0 = arith.constant 0.0 : f16
+      %c0 = arith.constant 0 : index
+      %c4 = arith.constant 4 : index
+      %c8 = arith.constant 8 : index
+
+      // Assume we have a linear thread id and the kernel launches 32 threads (1 warp),
+      // so the CUDA launch would be threadblock = (32, 1, 1), grid = (1, 1, 1).
+      %lane_id = gpu.thread_id x
+      // Which group of 4 threads (quad) do we belong to?
+      %quad_id = affine.apply affine_map<()[s0]->(s0 floordiv 4)>()[%lane_id]
+      // Are we an even (2i) or an odd (2i + 1) thread?
+      %pair_id = affine.apply affine_map<()[s0]->(s0 mod 2)>()[%lane_id]
+
+      // Now we have
+      //   MMA lane=0  quad=0 pair=0
+      //   MMA lane=1  quad=0 pair=1
+      //   MMA lane=2  quad=0 pair=0
+      //   MMA lane=3  quad=0 pair=1
+      //   MMA lane=4  quad=1 pair=0
+      //   MMA lane=5  quad=1 pair=1
+      //   ...
+      //   MMA lane=30 quad=7 pair=0
+      //   MMA lane=31 quad=7 pair=1
+      //
+      //gpu.printf "MMA lane=%lld quad=%lld pair=%lld\n" %lane_id, %quad_id, %pair_id : index, index, index
+
+      // Load and combine the two pieces of i16 metadata required. Obviously, it is
+      // possible to re-pack the metadata before launching the kernel in order
+      // to eliminate this cost and load a single i32 operand. This just shows
+      // how to put them together if you do the naive load per the diagram in
+      // the PTX docs. Technically only the first two threads in each quad need
+      // to do this.
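+      //
+      // For illustration only, a commented-out sketch of that re-packing of
+      // two i16 halves into a single 32-bit word (assuming the first half
+      // goes into the low bits; %m0 and %m1 are hypothetical values, not
+      // part of this test):
+      //   %lo          = arith.extui %m0 : i16 to i32
+      //   %hi          = arith.extui %m1 : i16 to i32
+      //   %c16_i32     = arith.constant 16 : i32
+      //   %hi_shifted  = arith.shli %hi, %c16_i32 : i32
+      //   %packed      = arith.ori %lo, %hi_shifted : i32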
+      %meta_row0 = arith.addi %quad_id, %c0 : index
+      %meta_row1 = arith.addi %quad_id, %c8 : index
+      %meta_A_per_thread0 = memref.load %argA_meta[%meta_row0, %pair_id] : memref<16x2xi16>
+      %meta_A_per_thread1 = memref.load %argA_meta[%meta_row1, %pair_id] : memref<16x2xi16>
+      %meta_init = arith.constant dense<0> : vector<2xi16>
+      %meta_low = vector.insert %meta_A_per_thread0, %meta_init[0] : i16 into vector<2xi16>
+      %meta = vector.insert %meta_A_per_thread1, %meta_low[1] : i16 into vector<2xi16>
+
+      // LOAD A
+      // Load the data fragments (the compressed A values). This can be done
+      // using ldmatrix, but here we just do naive individual loads.
+      %A_row0, %A_col0 = affine.delinearize_index %lane_id into (%c8, %c4) : index, index
+      %A_row8 = arith.addi %A_row0, %c8 : index
+      %A_col8 = arith.addi %A_col0, %c8 : index
+      %A_quad00 = vector.transfer_read %argA[%A_row0, %A_col0], %f0 {in_bounds = [true]} : memref<16x16xf16>, vector<2xf16>
+      %A_quad10 = vector.transfer_read %argA[%A_row8, %A_col0], %f0 {in_bounds = [true]} : memref<16x16xf16>, vector<2xf16>
+      %A_quad01 = vector.transfer_read %argA[%A_row0, %A_col8], %f0 {in_bounds = [true]} : memref<16x16xf16>, vector<2xf16>
+      %A_quad11 = vector.transfer_read %argA[%A_row8, %A_col8], %f0 {in_bounds = [true]} : memref<16x16xf16>, vector<2xf16>
+      %A_init0 = arith.constant dense<0.0> : vector<4x2xf16>
+      %A_data0 = vector.insert %A_quad00, %A_init0[0] : vector<2xf16> into vector<4x2xf16>
+      %A_data1 = vector.insert %A_quad10, %A_data0[1] : vector<2xf16> into vector<4x2xf16>
+      %A_data2 = vector.insert %A_quad01, %A_data1[2] : vector<2xf16> into vector<4x2xf16>
+      %A_data3 = vector.insert %A_quad11, %A_data2[3] : vector<2xf16> into vector<4x2xf16>
+
+      // LOAD B
+      // Load the data fragments for the dense B values. This can be done
+      // using ldmatrix, but here we just do naive individual loads.
+      %B_row0  = affine.apply affine_map<()[s0]->((s0 mod 4) * 2)>()[%lane_id]
+      %B_row8  = affine.apply affine_map<()[s0]->((s0 mod 4) * 2 + 8)>()[%lane_id]
+      %B_row16 = affine.apply affine_map<()[s0]->((s0 mod 4) * 2 + 16)>()[%lane_id]
+      %B_row24 = affine.apply affine_map<()[s0]->((s0 mod 4) * 2 + 24)>()[%lane_id]
+      %B_col = affine.apply affine_map<()[s0]->(s0 floordiv 4)>()[%lane_id]
+      %B_quad0 = vector.transfer_read %argB[%B_col, %B_row0], %f0 {in_bounds = [true]} : memref<8x32xf16>, vector<2xf16>
+      %B_quad1 = vector.transfer_read %argB[%B_col, %B_row8], %f0 {in_bounds = [true]} : memref<8x32xf16>, vector<2xf16>
+      %B_quad2 = vector.transfer_read %argB[%B_col, %B_row16], %f0 {in_bounds = [true]} : memref<8x32xf16>, vector<2xf16>
+      %B_quad3 = vector.transfer_read %argB[%B_col, %B_row24], %f0 {in_bounds = [true]} : memref<8x32xf16>, vector<2xf16>
+      %B_init0 = arith.constant dense<0.0> : vector<4x2xf16>
+      %B_data0 = vector.insert %B_quad0, %B_init0[0] : vector<2xf16> into vector<4x2xf16>
+      %B_data1 = vector.insert %B_quad1, %B_data0[1] : vector<2xf16> into vector<4x2xf16>
+      %B_data2 = vector.insert %B_quad2, %B_data1[2] : vector<2xf16> into vector<4x2xf16>
+      %B_data3 = vector.insert %B_quad3, %B_data2[3] : vector<2xf16> into vector<4x2xf16>
+
+      // For now just use an accumulator register that is zeroed out.
+      %accum = arith.constant dense<0.0> : vector<2x2xf16>
+
+      // Sparsity selector. For 16x8x32, "0" means that threads T0/T1 within
+      // each group of four threads contribute the metadata.
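+      // (A selector of "1" would mean that threads T2/T3 of each group of
+      // four contribute the metadata instead; see the PTX ISA description
+      // of mma.sp. The op below does not set a selector explicitly.)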
+      %d = nvgpu.mma.sp.sync(%A_data3, %B_data3, %accum)
+        metadata(%meta)
+        {mmaShape = [16, 8, 32]} : (vector<4x2xf16>, vector<4x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+
+      // Transfer the result back into host memory.
+      %C_0 = vector.extract %d[0] : vector<2x2xf16>
+      %C_1 = vector.extract %d[1] : vector<2x2xf16>
+      vector.transfer_write %C_0, %argC[%meta_row0, %c0] {in_bounds = [true]} : vector<2xf16>, memref<16x8xf16>
+      vector.transfer_write %C_1, %argC[%meta_row1, %c0] {in_bounds = [true]} : vector<2xf16>, memref<16x8xf16>
+      gpu.return
+    }
+  }
+
+  // Code that runs on the host.
+
+  //
+  // This test performs a matrix multiplication
+  //   C = A x B
+  // using NVidia 2:4 structured sparsity for A.
+  //
+  func.func @main() {
+    %f0 = arith.constant 0.0 : f16
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 1 : index
+    %c8 = arith.constant 8 : index
+    %c16 = arith.constant 16 : index
+    %c32 = arith.constant 32 : index
+    %c64 = arith.constant 64 : index
+
+    // Matrices A, B, C (16x32, 32x8, 16x8).
+    %a = memref.alloc() : memref<16x16xf16> // 16x32 but 2:4, row-major
+    %b = memref.alloc() : memref<8x32xf16>  // regular dense, column-major
+    %c = memref.alloc() : memref<16x8xf16>  // accumulator, row-major
+
+    // Metadata for A.
+    %m = memref.alloc() : memref<16x2xi16>
+
+    //
+    // Set up matrix A.
+    //
+    scf.for %ai = %c0 to %c16 step %c1 {
+      scf.for %aj = %c0 to %c16 step %c1 {
+        %a0 = arith.addi %ai, %aj : index
+        %a1 = arith.subi %c64, %a0 : index
+        %a2 = arith.index_cast %a1 : index to i32
+        %a3 = arith.sitofp %a2 : i32 to f16
+        memref.store %a3, %a[%ai, %aj] : memref<16x16xf16>
+      }
+    }
+
+    //
+    // Set up metadata for matrix A.
+    //
+    // Here we assume that all 2:4 elements are in positions 0 and 2,
+    // viz. in the matrix
+    //   | A 0 B 0 |
+    //   { 0 2 }
+    //
+    // 10 00 10 00 10 00 10 00 = 0x8888  (TODO: what is the endianness here?)
+    //
+    //%bits = arith.constant 0x8888 : i16
+    //%bits = arith.constant 0x2222 : i16
+    %bits = arith.constant 0x2020 : i16
+    scf.for %mi = %c0 to %c16 step %c1 {
+      memref.store %bits, %m[%mi, %c0] : memref<16x2xi16>
+      memref.store %bits, %m[%mi, %c1] : memref<16x2xi16>
+    }
+
+    //
+    // Set up matrix B.
+    //
+    scf.for %bi = %c0 to %c8 step %c1 {
+      scf.for %bj = %c0 to %c32 step %c1 {
+        %b0 = arith.addi %bi, %bj : index
+        %b1 = arith.addi %b0, %c16 : index
+        %b2 = arith.index_cast %b1 : index to i32
+        %b3 = arith.sitofp %b2 : i32 to f16
+        memref.store %b3, %b[%bi, %bj] : memref<8x32xf16>
+      }
+    }
+
+    //
+    // Reset matrix C.
+    //
+    scf.for %ci = %c0 to %c16 step %c1 {
+      scf.for %cj = %c0 to %c8 step %c1 {
+        memref.store %f0, %c[%ci, %cj] : memref<16x8xf16>
+      }
+    }
+
+    //
+    // Sanity check on the **compressed** input matrix A.
+    //
+    // Note that it really is a 16x32 matrix:
+    //   | 64 0 63 0 62 0 ....
+    //   | 63 0 62 0 61 0 ...
+ // + // CHECK: ( 64, 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49 ) + // CHECK-NEXT: ( 63, 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48 ) + // CHECK-NEXT: ( 62, 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47 ) + // CHECK-NEXT: ( 61, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46 ) + // CHECK-NEXT: ( 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45 ) + // CHECK-NEXT: ( 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44 ) + // CHECK-NEXT: ( 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43 ) + // CHECK-NEXT: ( 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42 ) + // CHECK-NEXT: ( 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41 ) + // CHECK-NEXT: ( 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40 ) + // CHECK-NEXT: ( 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39 ) + // CHECK-NEXT: ( 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38 ) + // CHECK-NEXT: ( 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37 ) + // CHECK-NEXT: ( 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36 ) + // CHECK-NEXT: ( 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35 ) + // CHECK-NEXT: ( 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34 ) + // + scf.for %pai = %c0 to %c16 step %c1 { + %pa0 = vector.transfer_read %a[%pai, %c0], %f0 : memref<16x16xf16>, vector<16xf16> + vector.print %pa0 : vector<16xf16> + } + + // + // Sanity check on input matrix 32x8 B. + // Note that this is really printed as B^T + // + // CHECK-NEXT: ( 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47 ) + // CHECK-NEXT: ( 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48 ) + // CHECK-NEXT: ( 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49 ) + // CHECK-NEXT: ( 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50 ) + // CHECK-NEXT: ( 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51 ) + // CHECK-NEXT: ( 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52 ) + // CHECK-NEXT: ( 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53 ) + // CHECK-NEXT: ( 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54 ) + // + scf.for %pbi = %c0 to %c8 step %c1 { + %pb0 = vector.transfer_read %b[%pbi, %c0], %f0 : memref<8x32xf16>, vector<32xf16> + vector.print %pb0 : vector<32xf16> + } + + // Maps the provided host buffer into the device address space. + // Writes from the host are guaranteed to be visible to device + // kernels that are launched afterwards. Writes from the device + // are guaranteed to be visible on the host after synchronizing + // with the device kernel completion. 
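+    //
+    // (For illustration only, a commented-out sketch of the alternative that
+    // uses explicit device buffers and copies instead of host registration;
+    // the names are hypothetical and this test does not use it:
+    //   %d_a = gpu.alloc () : memref<16x16xf16>
+    //   gpu.memcpy %d_a, %a : memref<16x16xf16>, memref<16x16xf16>
+    //   ... launch the kernel with %d_a ...
+    //   gpu.memcpy %a, %d_a : memref<16x16xf16>, memref<16x16xf16>
+    //   gpu.dealloc %d_a : memref<16x16xf16>)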
+ %cast_a = memref.cast %a : memref<16x16xf16> to memref<*xf16> + gpu.host_register %cast_a : memref<*xf16> + %cast_m = memref.cast %m : memref<16x2xi16> to memref<*xi16> + gpu.host_register %cast_m : memref<*xi16> + %cast_b = memref.cast %b : memref<8x32xf16> to memref<*xf16> + gpu.host_register %cast_b : memref<*xf16> + %cast_c = memref.cast %c : memref<16x8xf16> to memref<*xf16> + gpu.host_register %cast_c : memref<*xf16> + + // Call the kernel, using a single warp of 32 threads. + %t1 = arith.constant 1 : index + %t32 = arith.constant 32 : index + gpu.launch_func + @kernels::@mma_sp_sync_f16_16832 + blocks in (%t1, %t1, %t1) // gridSizeX,Y,Z + threads in (%t32, %t1, %t1) // blockSizeX,Y,Z + args(%a : memref<16x16xf16>, + %m : memref<16x2xi16>, + %b : memref<8x32xf16>, + %c : memref<16x8xf16>) + + // + // Verify computed matrix C. + // + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // CHECK-NEXT: ( 0, 0, 0, 0, 0, 0, 0, 0 ) + // + scf.for %pci = %c0 to %c16 step %c1 { + %pc0 = vector.transfer_read %c[%pci, %c0], %f0 : memref<16x8xf16>, vector<8xf16> + vector.print %pc0 : vector<8xf16> + } + + return + } +}