diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td
@@ -3027,10 +3027,12 @@
 def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>;
 def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
 def SPV_Composite :
-    AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>;
+    AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
+               SPV_AnyCooperativeMatrix]>;
 def SPV_Type : AnyTypeOf<[
     SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector,
-    SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct
+    SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct,
+    SPV_AnyCooperativeMatrix
   ]>;
 
 def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>;
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h
@@ -138,6 +138,10 @@
 
   Type getElementType(unsigned) const;
 
+  /// Return true if the number of elements is not known at compile time or is
+  /// implementation dependent.
+  bool hasDynamicNumElements() const;
+
   void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
                      Optional<StorageClass> storage = llvm::None);
   void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
@@ -334,7 +338,7 @@
 
 // SPIR-V cooperative matrix type
 class CooperativeMatrixNVType
-    : public Type::TypeBase<CooperativeMatrixNVType, SPIRVType,
+    : public Type::TypeBase<CooperativeMatrixNVType, CompositeType,
                             detail::CooperativeMatrixTypeStorage> {
 public:
   using Base::Base;
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -418,7 +418,11 @@
 
   for (auto index : indices) {
     if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
-      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
+      // Negative indices are never valid; the upper bound can only be checked
+      // when the element count is known at compile time.
+      if (index < 0 ||
+          (!cType.hasDynamicNumElements() &&
+           static_cast<uint64_t>(index) >= cType.getNumElements())) {
         emitErrorFn("index ") << index << " out of bounds for " << type;
         return nullptr;
       }
@@ -1098,7 +1102,8 @@
            << type;
   }
 
-  if (operands.size() != cType.getNumElements()) {
+  if (!cType.hasDynamicNumElements() &&
+      operands.size() != cType.getNumElements()) {
     return parser.emitError(loc, "has incorrect number of operands: expected ")
            << cType.getNumElements() << ", but provided " << operands.size();
   }
@@ -1107,8 +1112,8 @@
   // also be vectors with the same component type as the Result Type component
   // type".
   SmallVector<Type, 4> elementTypes;
-  elementTypes.reserve(cType.getNumElements());
-  for (auto index : llvm::seq<uint32_t>(0, cType.getNumElements())) {
+  elementTypes.reserve(operands.size());
+  for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
     elementTypes.push_back(cType.getElementType(index));
   }
   state.addTypes(type);
@@ -1126,7 +1131,15 @@
   auto cType = compositeConstructOp.getType().cast<spirv::CompositeType>();
   SmallVector<Value, 4> constituents(compositeConstructOp.constituents());
 
-  if (constituents.size() != cType.getNumElements()) {
+  if (cType.isa<spirv::CooperativeMatrixNVType>()) {
+    // SPV_NV_cooperative_matrix only allows constructing a cooperative matrix
+    // from exactly one scalar constituent.
+    if (constituents.size() != 1)
+      return compositeConstructOp.emitError(
+                 "has incorrect number of operands: expected ")
+             << "1, but provided " << constituents.size();
+  } else if (!cType.hasDynamicNumElements() &&
+             constituents.size() != cType.getNumElements()) {
     return compositeConstructOp.emitError(
                "has incorrect number of operands: expected ")
            << cType.getNumElements() << ", but provided "
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp
@@ -196,8 +196,8 @@
   case spirv::TypeKind::Array:
     return cast<ArrayType>().getNumElements();
   case spirv::TypeKind::CooperativeMatrix:
-    return cast<CooperativeMatrixNVType>().getRows() *
-           cast<CooperativeMatrixNVType>().getColumns();
+    llvm_unreachable(
+        "invalid to query number of elements of spirv::CooperativeMatrix type");
   case spirv::TypeKind::RuntimeArray:
     llvm_unreachable(
         "invalid to query number of elements of spirv::RuntimeArray type");
@@ -210,6 +210,16 @@
   }
 }
 
+bool CompositeType::hasDynamicNumElements() const {
+  switch (getKind()) {
+  case TypeKind::CooperativeMatrix:
+  case TypeKind::RuntimeArray:
+    return true;
+  default:
+    return false;
+  }
+}
+
 void CompositeType::getExtensions(
     SPIRVType::ExtensionArrayRefVector &extensions,
     Optional<StorageClass> storage) {
diff --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir
@@ -91,4 +91,12 @@
     %r = spv.FDiv %a, %b : !spv.coopmatrix<8x16xf32, Subgroup>
     spv.Return
   }
+
+  // CHECK-LABEL: @cooperative_matrix_access_chain
+  spv.func @cooperative_matrix_access_chain(%a : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, StorageBuffer>) "None" {
+    %0 = spv.constant 0: i32
+    // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, StorageBuffer>
+    %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, StorageBuffer>
+    spv.Return
+  }
 }
diff --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir
--- a/mlir/test/Dialect/SPIRV/composite-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir
@@ -20,6 +20,22 @@
 
 // -----
 
+func @composite_construct_coopmatrix(%arg0 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
+  // CHECK: spv.CompositeConstruct {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup>
+  %0 = spv.CompositeConstruct %arg0 : !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xf32, Subgroup>
+}
+
+// -----
+
+func @composite_construct_coopmatrix_invalid_num_of_elements(%arg0 : f32, %arg1 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> {
+  // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}}
+  %0 = spv.CompositeConstruct %arg0, %arg1 : !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xf32, Subgroup>
+}
+
+// -----
+
 func @composite_construct_empty_struct() -> !spv.struct<> {
   // CHECK: spv.CompositeConstruct : !spv.struct<>
   %0 = spv.CompositeConstruct : !spv.struct<>
@@ -80,6 +96,22 @@
 
 // -----
 
+func @composite_extract_coopmatrix(%arg0 : !spv.coopmatrix<8x16xf32, Subgroup>) -> f32 {
+  // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[2 : i32] : !spv.coopmatrix<8x16xf32, Subgroup>
+  %0 = spv.CompositeExtract %arg0[2 : i32] : !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0 : f32
+}
+
+// -----
+
+func @composite_extract_coopmatrix_negative_index(%arg0 : !spv.coopmatrix<8x16xf32, Subgroup>) -> f32 {
+  // expected-error @+1 {{index -1 out of bounds}}
+  %0 = spv.CompositeExtract %arg0[-1 : i32] : !spv.coopmatrix<8x16xf32, Subgroup>
+  return %0 : f32
+}
+
+// -----
+
 func @composite_extract_no_ssa_operand() -> () {
   // expected-error @+1 {{expected SSA operand}}
   %0 = spv.CompositeExtract [4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>>
@@ -200,6 +232,14 @@
 
 // -----
 
+func @composite_insert_coopmatrix(%arg0: !spv.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spv.coopmatrix<8x16xi32, Subgroup> {
+  // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spv.coopmatrix<8x16xi32, Subgroup>
+  %0 = spv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spv.coopmatrix<8x16xi32, Subgroup>
+  return %0: !spv.coopmatrix<8x16xi32, Subgroup>
+}
+
+// -----
+
 func @composite_insert_no_indices(%arg0: !spv.array<4xf32>, %arg1: f32) -> !spv.array<4xf32> {
   // expected-error @+1 {{expected at least one index}}
   %0 = spv.CompositeInsert %arg1, %arg0[] : f32 into !spv.array<4xf32>
diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
--- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir
@@ -91,3 +91,13 @@
   %r = spv.FDiv %a, %b : !spv.coopmatrix<8x16xf32, Subgroup>
   spv.Return
 }
+
+// -----
+
+// CHECK-LABEL: @cooperative_matrix_access_chain
+spv.func @cooperative_matrix_access_chain(%a : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, StorageBuffer>) "None" {
+  %0 = spv.constant 0: i32
+  // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, StorageBuffer>
+  %1 = spv.AccessChain %a[%0] : !spv.ptr<!spv.coopmatrix<8x16xf32, Subgroup>, StorageBuffer>
+  spv.Return
+}