Index: mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
===================================================================
--- mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -112,6 +112,8 @@
     `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)` `->` type($res)
   }];
+
+  let hasVerifier = 1;
 }
 
 
Index: mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
===================================================================
--- mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -88,5 +88,140 @@
   return success();
 }
 
+LogicalResult MmaSyncOp::verify() {
+
+  // Fundamental tensor core mma.sync op
+  // For F32 (TF32), F16, BF16, S8, and S4 data types the fundamental tensor
+  // core operation is of shape: 8-by-8-by-128b. F64 is an exception. The
+  // verification for mma.sync covering various shapes and data types is based
+  // on the fundamental tensor core operation.
+  constexpr int kThreads = 32; // 32 threads per warp
+  int64_t shapeM = 8;
+  int64_t shapeN = 8;
+  int64_t shapeK; // set based on data type (128b for all data types except F64)
+
+  // Number of elements A, B, and C per thread per fundamental tensor core tile
+  int64_t numElementA;    // set based on data type (32b except F64)
+  int64_t numElementB;    // set based on data type (32b except F64)
+  int64_t numElementC{2}; // two accumulator elements per fundamental tile
+
+  // nvgpu.mma.sync vector operands (per thread)
+  auto aVector = getMatrixA().getType().cast<VectorType>();
+  auto bVector = getMatrixB().getType().cast<VectorType>();
+  auto cVector = getMatrixC().getType().cast<VectorType>();
+
+  // vector shapes
+  auto aShape = aVector.getShape();
+  auto bShape = bVector.getShape();
+  auto cShape = cVector.getShape();
+
+  // Every check below reads dims 0 and 1 of each operand; reject other ranks
+  // up front rather than indexing past the end of the shape.
+  if (aShape.size() != 2 || bShape.size() != 2 || cShape.size() != 2) {
+    return emitOpError() << "expected matrix A, B, and C to be 2-D vectors";
+  }
+
+  // vector element type (matrix C's element type is not checked here)
+  auto aType = aVector.getElementType();
+  auto bType = bVector.getElementType();
+
+  // nvgpu.mma.sync shape (per 32 threads or per warp); guard the indexing
+  // below against a malformed attribute.
+  auto mmaShape = getMmaShape();
+  if (mmaShape.size() != 3) {
+    return emitOpError() << "expected mmaShape to have 3 entries (m, n, k)";
+  }
+  int64_t m = mmaShape[0].cast<IntegerAttr>().getInt();
+  int64_t n = mmaShape[1].cast<IntegerAttr>().getInt();
+  int64_t k = mmaShape[2].cast<IntegerAttr>().getInt();
+  if (m <= 0 || n <= 0 || k <= 0) {
+    return emitOpError() << "expected mmaShape entries to be positive";
+  }
+
+  if (aType.isF64()) {
+    // exception to 8-by-8-128b fundamental tensor core tile size
+    shapeK = 4;
+    numElementA = 1;
+    numElementB = 1;
+  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
+             aType.isInteger(8) || aType.isInteger(4)) {
+    // 8-by-8-128b fundamental tensor core tile size
+    int operandBitwidth = aType.getIntOrFloatBitWidth();
+    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
+    numElementA = 32 / operandBitwidth; // 32b wide operand A
+    numElementB = 32 / operandBitwidth; // 32b wide operand B
+  } else {
+    return emitOpError()
+           << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
+              "supported by nvgpu.mma.sync";
+  }
+
+  // Compare the full element types, not just their TypeIDs: e.g. i8 and i4
+  // share the IntegerType TypeID but are not a valid A/B combination.
+  if (aType != bType) {
+    return emitOpError()
+           << "expected same data type for matrix A and matrix B";
+  }
+
+  //
+  // Basic verification
+  //
+
+  // verify warp-wide size for vector a
+  if (aShape[0] * aShape[1] * kThreads != m * k) {
+    return emitOpError() << "expected " << m * k
+                         << " warp-wide matrix A elements";
+  }
+
+  // verify warp-wide size for vector b
+  if (bShape[0] * bShape[1] * kThreads != k * n) {
+    return emitOpError() << "expected " << k * n
+                         << " warp-wide matrix B elements";
+  }
+
+  // verify warp-wide size for vector c
+  if (cShape[0] * cShape[1] * kThreads != m * n) {
+    return emitOpError() << "expected " << m * n
+                         << " warp-wide matrix C elements";
+  }
+
+  //
+  // Extended verification
+  //
+
+  // mmaShape must be tiled exactly by the fundamental tensor core operation;
+  // otherwise the integer divisions below would silently truncate.
+  if (m % shapeM != 0 || n % shapeN != 0 || k % shapeK != 0) {
+    return emitOpError() << "expected mmaShape to be a multiple of ("
+                         << shapeM << " x " << shapeN << " x " << shapeK
+                         << ")";
+  }
+
+  // tiles of fundamental tensor core operations
+  int64_t mTile = m / shapeM;
+  int64_t nTile = n / shapeN;
+  int64_t kTile = k / shapeK;
+
+  // verify shape of aVector
+  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA))) {
+    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
+                         << " x " << numElementA << ")";
+  }
+
+  // verify shape of bVector
+  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB))) {
+    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
+                         << " x " << numElementB << ")";
+  }
+
+  // verify shape of cVector
+  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC))) {
+    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
+                         << " x " << numElementC << ")";
+  }
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
Index: mlir/test/Dialect/NVGPU/invalid.mlir
===================================================================
--- mlir/test/Dialect/NVGPU/invalid.mlir
+++ mlir/test/Dialect/NVGPU/invalid.mlir
@@ -1,4 +1,80 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
+func.func @m16n8k16_fp16_vector_shape_a(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected 256 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_b(%arg0: vector<4x2xf16>, %arg1: vector<2x4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected 128 warp-wide matrix B elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_c(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x4xf16>) -> vector<2x4xf16> {
+  // expected-error @+1 {{expected 128 warp-wide matrix C elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x4xf16>) -> vector<2x4xf16>
+  return %d : vector<2x4xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected matrix A to be shaped (4 x 2)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<2x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{expected 128 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  return %d : vector<2x2xf32>
+}
+// -----
+
+func.func @m16n8k8_fp32_vector_shape_a_extended(%arg0: vector<1x4xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{expected matrix A to be shaped (4 x 1)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<1x4xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  return %d : vector<2x2xf32>
+}
+// -----
+
+func.func @m8n8k4_fp64_vector_shape_a(%arg0: vector<1x2xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+  // expected-error @+1 {{expected 32 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x2xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+  return %d : vector<1x2xf64>
+}
+// -----
+
+func.func @m8n8k4_fp64_vector_shape_c_extended(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<2x1xf64>) -> vector<2x1xf64> {
+  // expected-error @+1 {{expected matrix C to be shaped (1 x 2)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<2x1xf64>) -> vector<2x1xf64>
+  return %d : vector<2x1xf64>
+}
+// -----
+
+func.func @m16n8k32_int8_vector_shape_b(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // expected-error @+1 {{expected 256 warp-wide matrix B elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+// -----
+
+func.func @m16n8k64_int4_vector_shape_a_extended(%arg0: vector<8x4xi4>, %arg1: vector<2x8xi4>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // expected-error @+1 {{expected matrix A to be shaped (4 x 8)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 64]} : (vector<8x4xi4>, vector<2x8xi4>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+// -----
+
+func.func @m16n8k32_int32_datatype(%arg0: vector<4x4xi32>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // expected-error @+1 {{expected input data type (i4,i8,f16,bf16,tf32,f64) supported by nvgpu.mma.sync}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi32>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+// -----
 
 func.func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %i : index) -> () {
   // expected-error @+1 {{destination memref must have memory space 3}}