diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCooperativeMatrixOps.td @@ -101,16 +101,17 @@ ``` {.ebnf} cooperative-matrixload-op ::= ssa-id `=` `spv.CooperativeMatrixLoadNV` - storage-class ssa-use `,` ssa-use `,` ssa-use + ssa-use `,` ssa-use `,` ssa-use (`[` memory-access `]`)? ` : ` + pointer-type `as` cooperative-matrix-type ``` For example: ``` - %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %colMajor - : !spv.coopmatrix + %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %colMajor + : !spv.ptr as !spv.coopmatrix ``` }]; @@ -243,16 +244,17 @@ ``` {.ebnf} coop-matrix-store-op ::= `spv.CooperativeMatrixStoreNV ` - storage-class ssa-use `, ` ssa-use `, ` ssa-use `, ` ssa-use `, ` - (`[` memory-access `]`)? `:` spirv-element-type + ssa-use `, ` ssa-use `, ` ssa-use `, ` ssa-use + (`[` memory-access `]`)? 
`:` + pointer-type `,` spirv-element-type ``` For example: ``` - spv.CooperativeMatrixStoreNV "StorageBuffer" %arg0, %arg2, %arg1, %arg3 : - !spv.coopmatrix + spv.CooperativeMatrixStoreNV %arg0, %arg2, %arg1, %arg3 : + !spv.ptr, !spv.coopmatrix ``` }]; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -2793,21 +2793,16 @@ static ParseResult parseCooperativeMatrixLoadNVOp(OpAsmParser &parser, OperationState &state) { - spirv::StorageClass storageClass; SmallVector operandInfo; Type strideType = parser.getBuilder().getIntegerType(32); Type columnMajorType = parser.getBuilder().getIntegerType(1); + Type ptrType; Type elementType; - if (parseEnumStrAttr(storageClass, parser) || - parser.parseOperandList(operandInfo, 3) || + if (parser.parseOperandList(operandInfo, 3) || parseMemoryAccessAttributes(parser, state) || parser.parseColon() || - parser.parseType(elementType)) { + parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) { return failure(); } - - auto ptrType = spirv::PointerType::get( - elementType.cast().getElementType(), - storageClass); SmallVector OperandType = {ptrType, strideType, columnMajorType}; if (parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(), state.operands)) { @@ -2819,25 +2814,30 @@ } static void print(spirv::CooperativeMatrixLoadNVOp M, OpAsmPrinter &printer) { - StringRef sc = stringifyStorageClass( - M.pointer().getType().cast().getStorageClass()); - printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " \"" << sc - << "\" " << M.pointer() << ", " << M.stride() << ", " - << M.columnmajor(); + printer << spirv::CooperativeMatrixLoadNVOp::getOperationName() << " " + << M.pointer() << ", " << M.stride() << ", " << M.columnmajor(); // Print optional memory access attribute. 
if (auto memAccess = M.memory_access()) printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; - printer << " : " << M.getType(); + printer << " : " << M.pointer().getType() << " as " << M.getType(); } static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, Type coopMatrix) { - if (pointer.cast().getPointeeType() != - coopMatrix.cast().getElementType()) + Type pointeeType = pointer.cast().getPointeeType(); + if (!pointeeType.isa() && !pointeeType.isa()) return op->emitError( - "expected the same type for pointer and the cooperative matrix" - "element, bu provided ") - << pointer << " and " << coopMatrix; + "Pointer must point to a scalar or vector type but provided ") + << pointeeType; + spirv::StorageClass storage = + pointer.cast().getStorageClass(); + if (storage != spirv::StorageClass::Workgroup && + storage != spirv::StorageClass::StorageBuffer && + storage != spirv::StorageClass::PhysicalStorageBuffer) + return op->emitError( + "Pointer storage class must be Workgroup, StorageBuffer or " + "PhysicalStorageBufferEXT but provided ") + << stringifyStorageClass(storage); return success(); } @@ -2847,21 +2847,17 @@ static ParseResult parseCooperativeMatrixStoreNVOp(OpAsmParser &parser, OperationState &state) { - spirv::StorageClass storageClass; SmallVector operandInfo; Type strideType = parser.getBuilder().getIntegerType(32); Type columnMajorType = parser.getBuilder().getIntegerType(1); + Type ptrType; Type elementType; - if (parseEnumStrAttr(storageClass, parser) || - parser.parseOperandList(operandInfo, 4) || + if (parser.parseOperandList(operandInfo, 4) || parseMemoryAccessAttributes(parser, state) || parser.parseColon() || + parser.parseType(ptrType) || parser.parseComma() || parser.parseType(elementType)) { return failure(); } - - auto ptrType = spirv::PointerType::get( - elementType.cast().getElementType(), - storageClass); SmallVector OperandType = {ptrType, elementType, strideType, columnMajorType}; if 
(parser.resolveOperands(operandInfo, OperandType, parser.getNameLoc(), @@ -2874,17 +2870,14 @@ static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix, OpAsmPrinter &printer) { - StringRef sc = stringifyStorageClass(coopMatrix.pointer() - .getType() - .cast() - .getStorageClass()); - printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " \"" - << sc << "\" " << coopMatrix.pointer() << ", " << coopMatrix.object() - << ", " << coopMatrix.stride() << ", " << coopMatrix.columnmajor(); + printer << spirv::CooperativeMatrixStoreNVOp::getOperationName() << " " + << coopMatrix.pointer() << ", " << coopMatrix.object() << ", " + << coopMatrix.stride() << ", " << coopMatrix.columnmajor(); // Print optional memory access attribute. if (auto memAccess = coopMatrix.memory_access()) printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; - printer << " : " << coopMatrix.getOperand(1).getType(); + printer << " : " << coopMatrix.pointer().getType() << ", " + << coopMatrix.getOperand(1).getType(); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir @@ -3,29 +3,29 @@ spv.module Logical GLSL450 requires #spv.vce { // CHECK-LABEL: @cooperative_matrix_load spv.func @cooperative_matrix_load(%ptr : !spv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup> - %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup> + // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr as !spv.coopmatrix<16x8xi32, Workgroup> + %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : 
!spv.ptr as !spv.coopmatrix<16x8xi32, Workgroup> spv.Return } // CHECK-LABEL: @cooperative_matrix_load_memaccess spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> - %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> + // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr as !spv.coopmatrix<8x16xi32, Subgroup> + %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr as !spv.coopmatrix<8x16xi32, Subgroup> spv.Return } // CHECK-LABEL: @cooperative_matrix_store spv.func @cooperative_matrix_store(%ptr : !spv.ptr, %stride : i32, %m : !spv.coopmatrix<16x8xi32, Workgroup>, %b : i1) "None" { - // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup> - spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup> + // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr, !spv.coopmatrix<16x8xi32, Workgroup> + spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b : !spv.ptr, !spv.coopmatrix<16x8xi32, Workgroup> spv.Return } // CHECK-LABEL: @cooperative_matrix_store_memaccess spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { - // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> - spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> + // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr, !spv.coopmatrix<8x16xi32, Subgroup> + spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b 
["Volatile"] : !spv.ptr, !spv.coopmatrix<8x16xi32, Subgroup> spv.Return } diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir --- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir @@ -2,30 +2,37 @@ // CHECK-LABEL: @cooperative_matrix_load spv.func @cooperative_matrix_load(%ptr : !spv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<16x8xi32, Workgroup> - %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b : !spv.coopmatrix<16x8xi32, Workgroup> + // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr as !spv.coopmatrix<16x8xi32, Workgroup> + %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr as !spv.coopmatrix<16x8xi32, Workgroup> spv.Return } // ----- // CHECK-LABEL: @cooperative_matrix_load_memaccess spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr, %stride : i32, %b : i1) "None" { - // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> - %0 = spv.CooperativeMatrixLoadNV "StorageBuffer" %ptr, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> + // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr as !spv.coopmatrix<8x16xi32, Subgroup> + %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b ["Volatile"] : !spv.ptr as !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return +} + +// CHECK-LABEL: @cooperative_matrix_load_diff_ptr_type +spv.func @cooperative_matrix_load_diff_ptr_type(%ptr : !spv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { + // CHECK: {{%.*}} = spv.CooperativeMatrixLoadNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup> + %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b 
["Volatile"] : !spv.ptr, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup> spv.Return } // CHECK-LABEL: @cooperative_matrix_store spv.func @cooperative_matrix_store(%ptr : !spv.ptr, %stride : i32, %m : !spv.coopmatrix<8x16xi32, Workgroup>, %b : i1) "None" { - // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} : !spv.coopmatrix<8x16xi32, Workgroup> - spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b : !spv.coopmatrix<8x16xi32, Workgroup> + // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} : !spv.ptr, !spv.coopmatrix<8x16xi32, Workgroup> + spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b : !spv.ptr, !spv.coopmatrix<8x16xi32, Workgroup> spv.Return } // CHECK-LABEL: @cooperative_matrix_store_memaccess spv.func @cooperative_matrix_store_memaccess(%ptr : !spv.ptr, %m : !spv.coopmatrix<8x16xi32, Subgroup>, %stride : i32, %b : i1) "None" { - // CHECK: spv.CooperativeMatrixStoreNV "StorageBuffer" {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> - spv.CooperativeMatrixStoreNV "StorageBuffer" %ptr, %m, %stride, %b ["Volatile"] : !spv.coopmatrix<8x16xi32, Subgroup> + // CHECK: spv.CooperativeMatrixStoreNV {{%.*}}, {{%.*}}, {{%.*}} ["Volatile"] : !spv.ptr, !spv.coopmatrix<8x16xi32, Subgroup> + spv.CooperativeMatrixStoreNV %ptr, %m, %stride, %b ["Volatile"] : !spv.ptr, !spv.coopmatrix<8x16xi32, Subgroup> spv.Return } @@ -134,3 +141,18 @@ spv.Return } +// ----- + +spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr, StorageBuffer>, %stride : i32, %b : i1) "None" { + // expected-error @+1 {{Pointer must point to a scalar or vector type}} + %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr, StorageBuffer> as !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return +} + +// ----- + +spv.func @cooperative_matrix_load_memaccess(%ptr : !spv.ptr, %stride : i32, %b : i1) "None" { + // expected-error @+1 {{Pointer storage class must be Workgroup, StorageBuffer or 
PhysicalStorageBufferEXT}} + %0 = spv.CooperativeMatrixLoadNV %ptr, %stride, %b : !spv.ptr as !spv.coopmatrix<8x16xi32, Subgroup> + spv.Return +}