Index: mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3027,10 +3027,12 @@
 def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
 def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
 def SPV_Composite :
-    AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
+    AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
+               SPV_AnyCooperativeMatrix]>;
 def SPV_Type : AnyTypeOf<[
     SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
-    SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
+    SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
+    SPV_AnyCooperativeMatrix
   ]>;
 
 def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
Index: mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
===================================================================
--- mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -138,6 +138,11 @@
 
   Type getElementType(unsigned) const;
 
+  /// Return true if the number of elements is not known at compile time or is
+  /// implementation dependent (runtime arrays and cooperative matrices).
+  /// getNumElements() must not be called on such types.
+  bool hasDynamicNumElements() const;
+
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                      Optional<StorageClass> storage = llvm::None);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
@@ -334,7 +339,7 @@
 
 // SPIR-V cooperative matrix type
 class CooperativeMatrixNVType
-    : public Type::TypeBase<CooperativeMatrixNVType, SPIRVType,
+    : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
                             detail::CooperativeMatrixTypeStorage> {
 public:
   using Base::Base;
Index: mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -418,7 +418,9 @@
 
   for (auto index : indices) {
     if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
-      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
+      if (index < 0 ||
+          (!cType.hasDynamicNumElements() &&
+           static_cast<uint64_t>(index) >= cType.getNumElements())) {
         emitErrorFn("index ") << index << " out of bounds for " << type;
         return nullptr;
       }
@@ -1098,7 +1100,8 @@
            << type;
   }
 
-  if (operands.size() != cType.getNumElements()) {
+  if (!cType.hasDynamicNumElements() &&
+      operands.size() != cType.getNumElements()) {
     return parser.emitError(loc, "has incorrect number of operands: expected ")
            << cType.getNumElements() << ", but provided " << operands.size();
   }
@@ -1107,8 +1110,8 @@
   // also be vectors with the same component type as the Result Type component
   // type".
   SmallVector<Type, 4> elementTypes;
-  elementTypes.reserve(cType.getNumElements());
-  for (auto index : llvm::seq<uint32_t>(0, cType.getNumElements())) {
+  elementTypes.reserve(operands.size());
+  for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
     elementTypes.push_back(cType.getElementType(index));
   }
   state.addTypes(type);
@@ -1124,13 +1127,28 @@
 
 static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) {
   auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
-
   SmallVector<Value, 4> constituents(compositeConstructOp.constituents());
-  if (constituents.size() != cType.getNumElements()) {
-    return compositeConstructOp.emitError(
-               "has incorrect number of operands: expected ")
-           << cType.getNumElements() << ", but provided "
-           << constituents.size();
+
+  if (cType.isa<spirv::CooperativeMatrixNVType>()) {
+    // A cooperative matrix is constructed from exactly one scalar constituent
+    // whose value initializes every element of the matrix.
+    if (constituents.size() != 1)
+      return compositeConstructOp.emitError(
+                 "has incorrect number of operands: expected ")
+             << "1, but provided " << constituents.size();
+  } else if (cType.hasDynamicNumElements()) {
+    // Other composites whose size is unknown at compile time (e.g. runtime
+    // arrays) cannot be constructed; getNumElements() is unreachable for them.
+    return compositeConstructOp.emitError(
+               "result type must have a compile-time known number of "
+               "elements, but provided ")
+           << cType;
+  } else {
+    if (constituents.size() != cType.getNumElements())
+      return compositeConstructOp.emitError(
+                 "has incorrect number of operands: expected ")
+             << cType.getNumElements() << ", but provided "
+             << constituents.size();
   }
 
   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
Index: mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
===================================================================
--- mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -196,8 +196,8 @@
   case spirv::TypeKind::Array:
     return cast<ArrayType>().getNumElements();
   case spirv::TypeKind::CooperativeMatrix:
-    return cast<CooperativeMatrixNVType>().getRows() *
-           cast<CooperativeMatrixNVType>().getColumns();
+    llvm_unreachable(
+        "invalid to query number of elements of spirv::CooperativeMatrix type");
   case spirv::TypeKind::RuntimeArray:
     llvm_unreachable(
         "invalid to query number of elements of spirv::RuntimeArray type");
@@ -210,6 +210,16 @@
   }
 }
 
+bool CompositeType::hasDynamicNumElements() const {
+  switch (getKind()) {
+  case TypeKind::CooperativeMatrix:
+  case TypeKind::RuntimeArray:
+    return true;
+  default:
+    return false;
+  }
+}
+
 void CompositeType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     Optional<StorageClass> storage) {
Index: mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
===================================================================
--- mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
+++ mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
@@ -91,4 +91,12 @@
     %r = spv.FDiv %a, %b : !spv.coopmatrix<8x16xf32, Subgroup>
     spv.Return
   }
+
+  // CHECK-LABEL: @cooperative_matrix_access_chain
+  spv.func @cooperative_matrix_access_chain(%a : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, StorageBuffer>) "None" {
+    %0 = spv.constant 0: i32
+    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, StorageBuffer>
+    %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, StorageBuffer>
+    spv.Return
+  }
 }
Index: mlir/test/Dialect/SPIRV/composite-ops.mlir
===================================================================
--- mlir/test/Dialect/SPIRV/composite-ops.mlir
+++ mlir/test/Dialect/SPIRV/composite-ops.mlir
@@ -20,6 +20,14 @@
 
 // -----
 
+func @composite_construct_coopmatrix(%arg0 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
+  // CHECK: spv.CompositeConstruct {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup>
+  %0 = spv.CompositeConstruct %arg0 : !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xf32, Subgroup>
+}
+
+// -----
+
 func @composite_construct_empty_struct() -> !spv.struct<> {
   // CHECK: spv.CompositeConstruct : !spv.struct<>
   %0 = spv.CompositeConstruct : !spv.struct<>
@@ -52,6 +60,22 @@
 
 // -----
 
+func @composite_construct_coopmatrix(%arg0 : f32, %arg1 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
+  // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
+  %0 = spv.CompositeConstruct %arg0, %arg1 : !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xf32, Subgroup>
+}
+
+// -----
+
+func @composite_construct_runtime_array(%arg0 : f32) -> !spv.rtarray<f32> {
+  // expected-error @+1 {{result type must have a compile-time known number of elements}}
+  %0 = spv.CompositeConstruct %arg0 : !spv.rtarray<f32>
+  return %0: !spv.rtarray<f32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spv.CompositeExtractOp
 //===----------------------------------------------------------------------===//
@@ -80,6 +104,22 @@
 
 // -----
 
+func @composite_extract_coopmatrix(%arg0 : !spv.coopmatrix<8x16xf32, Subgroup>) -> f32 {
+  // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[2 : i32] : !spv.coopmatrix<8x16xf32, Subgroup>
+  %0 = spv.CompositeExtract %arg0[2 : i32] : !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0 : f32
+}
+
+// -----
+
+func @composite_extract_coopmatrix_negative_index(%arg0 : !spv.coopmatrix<8x16xf32, Subgroup>) -> f32 {
+  // expected-error @+1 {{index -1 out of bounds for}}
+  %0 = spv.CompositeExtract %arg0[-1 : i32] : !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0 : f32
+}
+
+// -----
+
 func @composite_extract_no_ssa_operand() -> () {
   // expected-error @+1 {{expected SSA operand}}
   %0 = spv.CompositeExtract [4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>>
@@ -200,6 +240,14 @@
 
 // -----
 
+func @composite_insert_coopmatrix(%arg0: !spv.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spv.coopmatrix<8x16xi32, Subgroup> {
+  // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spv.coopmatrix<8x16xi32, Subgroup>
+  %0 = spv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spv.coopmatrix<8x16xi32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xi32, Subgroup>
+}
+
+// -----
+
 func @composite_insert_no_indices(%arg0: !spv.array<4xf32>, %arg1: f32) -> !spv.array<4xf32> {
   // expected-error @+1 {{expected at least one index}}
   %0 = spv.CompositeInsert %arg1, %arg0[] : f32 into !spv.array<4xf32>
Index: mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
===================================================================
--- mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
+++ mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -91,3 +91,13 @@
   %r = spv.FDiv %a, %b : !spv.coopmatrix<8x16xf32, Subgroup>
   spv.Return
 }
+
+// -----
+
+// CHECK-LABEL: @cooperative_matrix_access_chain
+spv.func @cooperative_matrix_access_chain(%a : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, StorageBuffer>) "None" {
+  %0 = spv.constant 0: i32
+  // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, StorageBuffer>
+  %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, StorageBuffer>
+  spv.Return
+}