diff --git a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
--- a/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
+++ b/mlir/include/mlir/Dialect/NVGPU/IR/NVGPU.td
@@ -81,7 +81,10 @@
   }];
 }
 
-def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [NoSideEffect]> {
+def NVGPU_MmaSyncOp : NVGPU_Op<"mma.sync", [
+  NoSideEffect,
+  PredOpTrait<"matrixA and matrixB have same element type",
+              TCopVTEtIsSameAs<0, 1>>]> {
   let description = [{
     The `nvgpu.mma.sync` op represents the distributed form of a collective
     matrix-multiply-and-accumulate (mma) operation that is compatible with
@@ -112,6 +115,8 @@
     `(` $matrixA`,` $matrixB`,` $matrixC `)` attr-dict
     `:` `(` type($matrixA) `,` type($matrixB) `,` type($matrixC) `)`
     `->` type($res)
   }];
+
+  let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
--- a/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
+++ b/mlir/lib/Dialect/NVGPU/IR/NVGPUDialect.cpp
@@ -88,5 +88,103 @@
   return success();
 }
 
+LogicalResult MmaSyncOp::verify() {
+
+  // Fundamental tensor core mma.sync op
+  // For F32 (TF32), F16, S8, and S4 data types, the fundamental tensor core
+  // operation is of shape 8-by-8-by-128b. F64 is an exception. The
+  // verification for mma.sync covering various shapes and data types is based
+  // on the fundamental tensor core operation.
+  constexpr int kThreads = 32; // 32 threads per warp
+  int64_t shapeM = 8;
+  int64_t shapeN = 8;
+  int64_t shapeK; // set based on data type (128b for all data types except F64)
+
+  // Number of elements A, B, and C per thread per fundamental tensor core tile
+  int64_t numElementA;    // set based on data type (32b except F64)
+  int64_t numElementB;    // set based on data type (32b except F64)
+  int64_t numElementC{2}; // two accumulator elements per fundamental tile
+
+  // nvgpu.mma.sync vector operands (per thread)
+  auto aVector = getMatrixA().getType().cast<VectorType>();
+  auto bVector = getMatrixB().getType().cast<VectorType>();
+  auto cVector = getMatrixC().getType().cast<VectorType>();
+
+  // vector shapes
+  ArrayRef<int64_t> aShape = aVector.getShape();
+  ArrayRef<int64_t> bShape = bVector.getShape();
+  ArrayRef<int64_t> cShape = cVector.getShape();
+
+  // vector element type
+  Type aType = aVector.getElementType();
+
+  // nvgpu.mma.sync shape (per 32 threads or per warp)
+  int64_t m = getMmaShape()[0].cast<IntegerAttr>().getInt();
+  int64_t n = getMmaShape()[1].cast<IntegerAttr>().getInt();
+  int64_t k = getMmaShape()[2].cast<IntegerAttr>().getInt();
+
+  if (aType.isF64()) {
+    // exception to the 8-by-8-by-128b fundamental tensor core tile size
+    shapeK = 4;
+    numElementA = 1;
+    numElementB = 1;
+  } else if (aType.isF32() || aType.isBF16() || aType.isF16() ||
+             aType.isInteger(8) || aType.isInteger(4)) {
+    // 8-by-8-by-128b fundamental tensor core tile size
+    int operandBitwidth = aType.getIntOrFloatBitWidth();
+    shapeK = 128 / operandBitwidth;     // 128b wide shapeK
+    numElementA = 32 / operandBitwidth; // 32b wide operand A
+    numElementB = 32 / operandBitwidth; // 32b wide operand B
+  } else {
+    return emitError() << "expected input data type (i4,i8,f16,bf16,tf32,f64) "
+                          "supported by nvgpu.mma.sync";
+  }
+
+  //
+  // Basic verification
+  //
+
+  // verify warp-wide size for vector a
+  if (aShape[0] * aShape[1] * kThreads != m * k)
+    return emitOpError() << "expected " << m * k
+                         << " warp-wide matrix A elements";
+
+  // verify warp-wide size for vector b
+  if (bShape[0] * bShape[1] * kThreads != k * n)
+    return emitOpError() << "expected " << k * n
+                         << " warp-wide matrix B elements";
+
+  // verify warp-wide size for vector c
+  if (cShape[0] * cShape[1] * kThreads != m * n)
+    return emitOpError() << "expected " << m * n
+                         << " warp-wide matrix C elements";
+
+  //
+  // Extended verification
+  //
+
+  // tiles of fundamental tensor core operations
+  int64_t mTile = m / shapeM;
+  int64_t nTile = n / shapeN;
+  int64_t kTile = k / shapeK;
+
+  // verify shape of aVector
+  if (!((aShape[0] == mTile * kTile) && (aShape[1] == numElementA)))
+    return emitOpError() << "expected matrix A to be shaped (" << mTile * kTile
+                         << " x " << numElementA << ")";
+
+  // verify shape of bVector
+  if (!((bShape[0] == kTile * nTile) && (bShape[1] == numElementB)))
+    return emitOpError() << "expected matrix B to be shaped (" << kTile * nTile
+                         << " x " << numElementB << ")";
+
+  // verify shape of cVector
+  if (!((cShape[0] == mTile * nTile) && (cShape[1] == numElementC)))
+    return emitOpError() << "expected matrix C to be shaped (" << mTile * nTile
+                         << " x " << numElementC << ")";
+
+  return success();
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/NVGPU/IR/NVGPU.cpp.inc"
diff --git a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
--- a/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir
@@ -205,7 +205,7 @@
 // -----
 
 // CHECK-LABEL: @m16n8k4_tf32
-func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<4x1xf32>) -> vector<4x1xf32> {
+func.func @m16n8k4_tf32(%arg0: vector<2x1xf32>, %arg1: vector<1x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
   // The A, B operand should be bitcast to i32
   // CHECK: llvm.extractvalue
   // CHECK: llvm.bitcast {{.*}} : vector<1xf32> to i32
@@ -219,17 +219,22 @@
   // CHECK-SAME: multiplicandBPtxType = #nvvm.mma_type<tf32>
   // CHECK-SAME: shape = #nvvm.shape<m = 16, n = 8, k = 4>
   // CHECK-SAME: -> !llvm.struct<(f32, f32, f32, f32)>
-  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<4x1xf32>) -> vector<4x1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][0]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][1]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][2]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK: [[el:%.+]] = llvm.extractvalue [[d]][3]
-  // CHECK: llvm.bitcast [[el]] : f32 to vector<1xf32>
-  // CHECK-COUNT-4: llvm.insertvalue {{.*}} : !llvm.array<4 x vector<1xf32>>
-  return %d : vector<4x1xf32>
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 4]} : (vector<2x1xf32>, vector<1x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+  // CHECK-DAG: llvm.extractvalue [[d]][0] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][1] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d00:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d01:%.+]] = llvm.insertelement {{%.+}}, [[d00]][{{.*}}] : vector<2xf32>
+
+  // CHECK: [[undef:%.+]] = llvm.mlir.undef : vector<2xf32>
+  // CHECK-DAG: llvm.extractvalue [[d]][2] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK-DAG: llvm.extractvalue [[d]][3] : !llvm.struct<(f32, f32, f32, f32)>
+  // CHECK: [[d10:%.+]] = llvm.insertelement {{%.+}}, [[undef]][{{.*}}] : vector<2xf32>
+  // CHECK: [[d11:%.+]] = llvm.insertelement {{%.+}}, [[d10]][{{.*}}] : vector<2xf32>
+
+  // CHECK-DAG: llvm.insertvalue [[d01]], {{%.+}}[0] : !llvm.array<2 x vector<2xf32>>
+  // CHECK-DAG: llvm.insertvalue [[d11]], {{%.+}}[1] : !llvm.array<2 x vector<2xf32>>
+  return %d : vector<2x2xf32>
 }
 
 // -----
diff --git a/mlir/test/Dialect/NVGPU/invalid.mlir b/mlir/test/Dialect/NVGPU/invalid.mlir
--- a/mlir/test/Dialect/NVGPU/invalid.mlir
+++ b/mlir/test/Dialect/NVGPU/invalid.mlir
@@ -1,4 +1,73 @@
 // RUN: mlir-opt -split-input-file -verify-diagnostics %s
+func.func @m16n8k16_fp16_vector_shape_a(%arg0: vector<4x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected 256 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_b(%arg0: vector<4x2xf16>, %arg1: vector<2x4xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected 128 warp-wide matrix B elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x4xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_c(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x4xf16>) -> vector<2x4xf16> {
+  // expected-error @+1 {{expected 128 warp-wide matrix C elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x4xf16>) -> vector<2x4xf16>
+  return %d : vector<2x4xf16>
+}
+// -----
+
+func.func @m16n8k16_fp16_vector_shape_a_extended(%arg0: vector<2x4xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
+  // expected-error @+1 {{expected matrix A to be shaped (4 x 2)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<2x4xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
+  return %d : vector<2x2xf16>
+}
+// -----
+
+func.func @m16n8k8_fp32_vector_shape_a(%arg0: vector<4x2xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{expected 128 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<4x2xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  return %d : vector<2x2xf32>
+}
+// -----
+
+func.func @m16n8k8_fp32_vector_shape_a_extended(%arg0: vector<1x4xf32>, %arg1: vector<2x1xf32>, %arg2: vector<2x2xf32>) -> vector<2x2xf32> {
+  // expected-error @+1 {{expected matrix A to be shaped (4 x 1)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 8]} : (vector<1x4xf32>, vector<2x1xf32>, vector<2x2xf32>) -> vector<2x2xf32>
+  return %d : vector<2x2xf32>
+}
+// -----
+
+func.func @m8n8k4_fp64_vector_shape_a(%arg0: vector<1x2xf64>, %arg1: vector<1x1xf64>, %arg2: vector<1x2xf64>) -> vector<1x2xf64> {
+  // expected-error @+1 {{expected 32 warp-wide matrix A elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x2xf64>, vector<1x1xf64>, vector<1x2xf64>) -> vector<1x2xf64>
+  return %d : vector<1x2xf64>
+}
+// -----
+
+func.func @m8n8k4_fp64_vector_shape_c_extended(%arg0: vector<1x1xf64>, %arg1: vector<1x1xf64>, %arg2: vector<2x1xf64>) -> vector<2x1xf64> {
+  // expected-error @+1 {{expected matrix C to be shaped (1 x 2)}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [8, 8, 4]} : (vector<1x1xf64>, vector<1x1xf64>, vector<2x1xf64>) -> vector<2x1xf64>
+  return %d : vector<2x1xf64>
+}
+// -----
+
+func.func @m16n8k32_int8_vector_shape_b(%arg0: vector<4x4xi8>, %arg1: vector<4x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // expected-error @+1 {{expected 256 warp-wide matrix B elements}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi8>, vector<4x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+// -----
+
+func.func @m16n8k32_int32_datatype(%arg0: vector<4x4xi32>, %arg1: vector<2x4xi8>, %arg2: vector<2x2xi32>) -> vector<2x2xi32> {
+  // expected-error @+1 {{op failed to verify that matrixA and matrixB have same element type}}
+  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 32]} : (vector<4x4xi32>, vector<2x4xi8>, vector<2x2xi32>) -> vector<2x2xi32>
+  return %d : vector<2x2xi32>
+}
+// -----
 
 func.func @async_cp_memory_space(%dst : memref<16xf32>, %src : memref<16xf32>, %i : index) -> () {
   // expected-error @+1 {{destination memref must have memory space 3}}
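
For cross-checking the shape rules in the verifier, a minimal sketch (not part of the patch) of an op the new checks should accept; the function name @m16n8k16_fp16_valid is illustrative. For m16n8k16 with f16 operands, shapeK = 128/16 = 8, so mTile = 2, nTile = 1, kTile = 2, and the per-thread operands come out as A = vector<4x2xf16>, B = vector<2x2xf16>, C = vector<2x2xf16>.

func.func @m16n8k16_fp16_valid(%arg0: vector<4x2xf16>, %arg1: vector<2x2xf16>, %arg2: vector<2x2xf16>) -> vector<2x2xf16> {
  // Warp-wide element counts: A has 4*2*32 = 256 = m*k, B has 2*2*32 = 128 = k*n,
  // and C has 2*2*32 = 128 = m*n, so both the basic and extended checks pass.
  %d = nvgpu.mma.sync (%arg0, %arg1, %arg2) {mmaShape = [16, 8, 16]} : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
  return %d : vector<2x2xf16>
}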