Index: mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
===================================================================
--- mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -54,7 +54,10 @@ Op {}
 
 def NVGPU_LdMatrixOp : NVGPU_Op<"ldmatrix",
-  [MemoryEffects<[MemRead]>]> {
+  [
+    MemoryEffects<[MemRead]>,
+    PredOpTrait<"srcMemref and res have same element type", TCresVTEtIsSameAsOp<0, 0>>,
+  ]> {
   let description = [{
     The `nvgpu.ldmatrix` op represents loading a matrix fragment from
     memory. The load source and result type must be compatible with
     lowering
@@ -79,6 +82,8 @@
   let assemblyFormat = [{
     $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
   }];
+
+  let hasVerifier = 1;
 }
 
 def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
Index: mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
===================================================================
--- mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -88,6 +88,10 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU_MmaSyncOp
+//===----------------------------------------------------------------------===//
+
 LogicalResult MmaSyncOp::verify() {
 
   // Fundamental tensor core mma.sync op
@@ -186,5 +190,55 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// NVGPU_LdMatrixOp
+//===----------------------------------------------------------------------===//
+LogicalResult LdMatrixOp::verify() {
+
+  // src: memref to shared memory
+  auto srcMemref = getSrcMemref().getType().cast<MemRefType>();
+
+  // dst: vector registers
+  auto resVector = getRes().getType().cast<VectorType>();
+  ArrayRef<int64_t> resShape = resVector.getShape();
+
+  Type resType = resVector.getElementType();
+  int64_t elementBitWidth = resType.getIntOrFloatBitWidth();
+
+  // ldmatrix loads 32 bits into vector registers per 8-by-8 tile per thread
+  int64_t numElementsPer32b = 32 / elementBitWidth;
+
+  // number of ldmatrix 8-by-8 tiles
+  int64_t numTiles = getNumTiles();
+
+  // transpose elements in vector registers at 16b granularity when true
+  bool isTranspose = getTranspose();
+
+  // address space id for shared memory
+  unsigned smemAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
+
+  //
+  // verification
+  //
+
+  if (srcMemref.getMemorySpaceAsInt() != smemAddressSpace)
+    return emitError()
+           << "expected nvgpu.ldmatrix srcMemref must have memory space "
+           << smemAddressSpace;
+  if (elementBitWidth > 32)
+    return emitError() << "nvgpu.ldmatrix works for 32b or lower";
+  if (isTranspose && elementBitWidth != 16)
+    return emitError()
+           << "nvgpu.ldmatrix transpose works only at 16b granularity";
+  if (resShape[1] != numElementsPer32b)
+    return emitError() << "expected vector register shape[1] = "
+                       << numElementsPer32b;
+  if (resShape[0] != numTiles)
+    return emitError()
+           << "expected vector register shape[0] and numTiles to match";
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
Index: mlir/test/Dialect/NVGPU/invalid.mlir
===================================================================
--- mlir/test/Dialect/NVGPU/invalid.mlir
+++ mlir/test/Dialect/NVGPU/invalid.mlir
@@ -1,4 +1,53 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
+
+func.func @ldmatrix_address_space_f16_x4(%arg0: memref<128x128xf16, 2>) -> vector<4x1xf16> {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{expected nvgpu.ldmatrix srcMemref must have memory space 3}}
+  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 2> -> vector<4x1xf16>
+  return %a : vector<4x1xf16>
+}
+// -----
+
+func.func @ldmatrix_num_elements_f16_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x1xf16> {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{expected vector register shape[1] = 2}}
+  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x1xf16>
+  return %a : vector<4x1xf16>
+}
+// -----
+
+func.func @ldmatrix_num_tiles_f16_x4(%arg0: memref<128x128xf16, 3>) -> vector<2x2xf16> {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{expected vector register shape[0] and numTiles to match}}
+  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<2x2xf16>
+  return %a : vector<2x2xf16>
+}
+// -----
+
+func.func @ldmatrix_num_tiles_f32_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x2xf32> {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{expected vector register shape[1] = 1}}
+  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf32, 3> -> vector<4x2xf32>
+  return %a : vector<4x2xf32>
+}
+// -----
+
+func.func @ldmatrix_trans_f32_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x1xf32> {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{nvgpu.ldmatrix transpose works only at 16b granularity}}
+  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = true, numTiles = 4 : i32} : memref<128x128xf32, 3> -> vector<4x1xf32>
+  return %a : vector<4x1xf32>
+}
+// -----
+
+func.func @ldmatrix_type_x4(%arg0: memref<128x128xf32, 3>) -> vector<4x2xf16> {
+  %c0 = arith.constant 0 : index
+  // expected-error @+1 {{'nvgpu.ldmatrix' op failed to verify that srcMemref and res have same element type}}
+  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf32, 3> -> vector<4x2xf16>
+  return %a : vector<4x2xf16>
+}
+// -----
+
 func.func @m16n8k16_fp16_vector_shape_a(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
   // expected-error @+1 {{expected 256 warp-wide matrix A elements}}
   %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
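Note (reviewer sketch, not part of the patch): for reference, below is a hand-written positive example that should pass every new check, with the function name chosen purely for illustration. With f16 elements, numElementsPer32b = 32/16 = 2, so numTiles = 4 requires a vector<4x2xf16> result; the memref lives in memory space 3 (workgroup/shared), and its element type matches the result's, satisfying the new PredOpTrait.

func.func @ldmatrix_valid_f16_x4(%arg0: memref<128x128xf16, 3>) -> vector<4x2xf16> {
  %c0 = arith.constant 0 : index
  // shape[0] = numTiles = 4, shape[1] = 32b / 16b = 2; transpose would
  // also be legal here since the element type is 16b wide
  %a = nvgpu.ldmatrix %arg0[%c0, %c0] {transpose = false, numTiles = 4 : i32} : memref<128x128xf16, 3> -> vector<4x2xf16>
  return %a : vector<4x2xf16>
}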