diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp --- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp +++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp @@ -231,6 +231,9 @@ } LogicalResult MmaSparseSyncOp::verify() { + unsigned sparsitySelector = getSparsitySelector(); + if (sparsitySelector > 1) + return emitOpError() << "sparsity selector should be 0 or 1"; return verifyMmaSyncOp(this->getOperation(), getMatrixA(), getMatrixB(), getMatrixC(), getMmaShapeAsArray(), getOperation()->hasAttr(getTf32EnabledAttrName()), diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir --- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir +++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir @@ -385,6 +385,29 @@ // ----- +// CHECK-LABEL: func @mma_sp_sync_f16_16816_01( +func.func @mma_sp_sync_f16_16816_01(%arg0: vector<2x2xf16>, + %arg1: vector<2x2xf16>, + %arg2: vector<2x2xf16>, + %arg3: vector<2xi16>) -> vector<2x2xf16> { + // + // As above, but with sparsity selection 0x01. 
+ // + // CHECK: %[[sparseMetadata:.+]] = llvm.bitcast %{{.+}} : vector<2xi16> to i32 + // CHECK: %[[d:.+]] = llvm.inline_asm has_side_effects asm_dialect = att + // CHECK-SAME: "mma.sp.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {$0,$1},{$2,$3},{$4,$5},{$6,$7},$8,0x1;" + // CHECK-SAME: "=r,=r,r,r,r,r,r,r,r" + // CHECK-SAME: %[[sparseMetadata]] : + // CHECK-SAME: -> !llvm.struct<(vector<2xf16>, vector<2xf16>)> + + %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) + {mmaShape = [16, 8, 16], sparsitySelector = 1 : i32} : + (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> + return %d : vector<2x2xf16> +} + +// ----- + // CHECK-LABEL: func @mma_sp_sync_i8_16864( func.func @mma_sp_sync_i8_16864(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir --- a/mlir/test/Dialect/NVGPU/invalid.mlir +++ b/mlir/test/Dialect/NVGPU/invalid.mlir @@ -130,7 +130,6 @@ nvgpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16xf32> to memref<16xf32> return } - // ----- func.func @async_cp_memref_type(%dst : memref<16xi32, 3>, %src : memref<16xf32>, %i : index) -> () { @@ -138,7 +137,6 @@ nvgpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16xf32> to memref<16xi32, 3> return } - // ----- func.func @async_cp_num_src_indices(%dst : memref<16xf32, 3>, %src : memref<16x16xf32>, %i : index) -> () { @@ -146,7 +144,6 @@ nvgpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16x16xf32> to memref<16xf32, 3> return } - // ----- func.func @async_cp_num_dst_indices(%dst : memref<16x16xf32, 3>, %src : memref<16xf32>, %i : index) -> () { @@ -154,7 +151,6 @@ nvgpu.device_async_copy %src[%i], %dst[%i], 16 : memref<16xf32> to memref<16x16xf32, 3> return } - // ----- func.func @async_cp_num_src_stride( @@ -166,7 +162,6 @@ memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>> to memref<200x100xf32, 3> return } - // ----- func.func @async_cp_num_dst_stride( @@ -178,3 +173,15 @@ 
memref<200x100xf32> to memref<200x100xf32, affine_map<(d0, d1) -> (200*d0 + 2*d1)>, 3> return } +// ----- + +// The sparsity selector must be 0 or 1; the verifier rejects 42. +func.func @mma_sp_sync_f16_16816(%arg0: vector<2x2xf16>, + %arg1: vector<2x2xf16>, + %arg2: vector<2x2xf16>, + %arg3: vector<2xi16>) -> vector<2x2xf16> { + // expected-error @+1 {{'nvgpu.mma.sp.sync' op sparsity selector should be 0 or 1}} + %d = nvgpu.mma.sp.sync(%arg0, %arg1, %arg2) metadata(%arg3) {mmaShape = [16, 8, 16], sparsitySelector = 42 : i32} : + (vector<2x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16> + return %d : vector<2x2xf16> +}