diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUBase.td
@@ -68,6 +68,9 @@
 def GPU_MMAMatrix : DialectType<
   GPU_Dialect, IsMMAMatrixTypePred, "MMAMatrix type">;
 
+// Memref type acceptable to gpu.subgroup_mma_{load|store}_matrix ops.
+def GPU_MMAMemRef : MemRefOf<[F16, F32, VectorOfRankAndType<[1], [F16, F32]>]>;
+
 class MMAMatrixOf<list<Type> allowedTypes> :
   ContainerType<AnyTypeOf<allowedTypes>, IsMMAMatrixTypePred,
   "$_self.cast<::mlir::gpu::MMAMatrixType>().getElementType()",
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -1114,7 +1114,7 @@
     ```
   }];
 
-  let arguments = (ins Arg<MemRefRankOf<[F16, F32], [2]>, "", [MemRead]>:$srcMemref,
+  let arguments = (ins Arg<GPU_MMAMemRef, "", [MemRead]>:$srcMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension);
 
@@ -1153,7 +1153,7 @@
   }];
 
   let arguments = (ins Arg<MMAMatrixOf<[F16, F32]>>:$src,
-                  Arg<MemRefRankOf<[F16, F32], [2]>, "",[MemWrite]>:$dstMemref,
+                  Arg<GPU_MMAMemRef, "", [MemWrite]>:$dstMemref,
                   Variadic<Index>:$indices,
                   IndexAttr:$leadDimension);
 
diff --git a/mlir/test/Dialect/GPU/invalid.mlir b/mlir/test/Dialect/GPU/invalid.mlir
--- a/mlir/test/Dialect/GPU/invalid.mlir
+++ b/mlir/test/Dialect/GPU/invalid.mlir
@@ -515,6 +515,14 @@
 
 // -----
 
+func.func @mma_invalid_memref_type(%src: memref<32x4xvector<4x8xf32>>, %i: index) {
+  // expected-error @+1 {{operand #0 must be memref of 16-bit float or 32-bit float or vector of 16-bit float or 32-bit float values of ranks 1 values}}
+  %0 = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4x8xf32>> -> !gpu.mma_matrix<16x16xf16, "AOp">
+  return
+}
+
+// -----
+
 #layout_map_col_major = affine_map<(i, j) -> (j, i)>
 
 func.func @wmmaStoreOp_invalid_map(%arg0 : !gpu.mma_matrix<16x16xf16, "COp">) -> () {
diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir
--- a/mlir/test/Dialect/GPU/ops.mlir
+++ b/mlir/test/Dialect/GPU/ops.mlir
@@ -265,8 +265,8 @@
     return
   }
 
-  func.func @mmamatrix_valid_element_type(%src : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>>){
-    // CHECK-LABEL: func @mmamatrix_valid_element_type
+  func.func @mmamatrix_valid_scalar_element_type(%src : memref<32x32xf16, affine_map<(d0, d1) -> (d0 * 64 + d1)>>){
+    // CHECK-LABEL: func @mmamatrix_valid_scalar_element_type
     %wg = memref.alloca() {alignment = 32} : memref<32x32xf16, 3>
     // CHECK: %[[wg:.*]] = memref.alloca()
     %i = arith.constant 16 : index
@@ -285,6 +285,15 @@
     return
   }
 
+  // CHECK-LABEL: func @mmamatrix_valid_vector_element_type
+  func.func @mmamatrix_valid_vector_element_type(%src : memref<32x4xvector<4xf32>>, %i : index) {
+    // CHECK: gpu.subgroup_mma_load_matrix
+    %s = gpu.subgroup_mma_load_matrix %src[%i, %i] {leadDimension = 4 : index} : memref<32x4xvector<4xf32>> -> !gpu.mma_matrix<16x16xf16, "COp">
+    // CHECK: gpu.subgroup_mma_store_matrix
+    gpu.subgroup_mma_store_matrix %s, %src[%i, %i] {leadDimension = 4 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x4xvector<4xf32>>
+    return
+  }
+
   // CHECK-LABEL: func @set_default_device
   func.func @set_default_device(%arg0: i32) {
     // CHECK: gpu.set_default_device