diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3027,10 +3027,12 @@ def SPV_Scalar : AnyTypeOf<[SPV_Numerical, SPV_Bool]>; def SPV_Aggregate : AnyTypeOf<[SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>; def SPV_Composite : - AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct]>; + AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct, + SPV_AnyCooperativeMatrix]>; def SPV_Type : AnyTypeOf<[ SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector, - SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct + SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct, + SPV_AnyCooperativeMatrix ]>; def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -134,10 +134,16 @@ /// Returns true if the given vector type is valid for the SPIR-V dialect. static bool isValid(VectorType); + /// Return the number of elements of the type. This should only be called if + /// hasCompileTimeKnownNumElements is true. unsigned getNumElements() const; Type getElementType(unsigned) const; + /// Return true if the number of elements is known at compile time and is not + /// implementation dependent. + bool hasCompileTimeKnownNumElements() const; + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage = llvm::None); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, @@ -334,7 +340,7 @@ // SPIR-V cooperative matrix type class CooperativeMatrixNVType - : public Type::TypeBase { public: using Base::Base; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -418,7 +418,9 @@ for (auto index : indices) { if (auto cType = type.dyn_cast()) { - if (index < 0 || static_cast(index) >= cType.getNumElements()) { + if (cType.hasCompileTimeKnownNumElements() && + (index < 0 || + static_cast(index) >= cType.getNumElements())) { emitErrorFn("index ") << index << " out of bounds for " << type; return nullptr; } @@ -1098,7 +1100,8 @@ << type; } - if (operands.size() != cType.getNumElements()) { + if (cType.hasCompileTimeKnownNumElements() && + operands.size() != cType.getNumElements()) { return parser.emitError(loc, "has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided " << operands.size(); } @@ -1107,8 +1110,8 @@ // also be vectors with the same component type as the Result Type component // type". SmallVector elementTypes; - elementTypes.reserve(cType.getNumElements()); - for (auto index : llvm::seq(0, cType.getNumElements())) { + elementTypes.reserve(operands.size()); + for (auto index : llvm::seq(0, operands.size())) { elementTypes.push_back(cType.getElementType(index)); } state.addTypes(type); @@ -1124,13 +1127,19 @@ static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { auto cType = compositeConstructOp.getType().cast(); - SmallVector constituents(compositeConstructOp.constituents()); - if (constituents.size() != cType.getNumElements()) { - return compositeConstructOp.emitError( - "has incorrect number of operands: expected ") - << cType.getNumElements() << ", but provided " - << constituents.size(); + + if (cType.isa()) { + if (constituents.size() != 1) + return compositeConstructOp.emitError( + "has incorrect number of operands: expected ") + << "1, but provided " << constituents.size(); + } else { + if (constituents.size() != cType.getNumElements()) + return compositeConstructOp.emitError( + "has incorrect number of operands: expected ") + << cType.getNumElements() << ", but provided " + << constituents.size(); } for (auto index : llvm::seq(0, constituents.size())) { diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -196,8 +196,8 @@ case spirv::TypeKind::Array: return cast().getNumElements(); case spirv::TypeKind::CooperativeMatrix: - return cast().getRows() * - cast().getColumns(); + llvm_unreachable( + "invalid to query number of elements of spirv::CooperativeMatrix type"); case spirv::TypeKind::RuntimeArray: llvm_unreachable( "invalid to query number of elements of spirv::RuntimeArray type"); @@ -210,6 +210,16 @@ } } +bool CompositeType::hasCompileTimeKnownNumElements() const { + switch (getKind()) { + case TypeKind::CooperativeMatrix: + case TypeKind::RuntimeArray: + return false; + default: + return true; + } +} + void CompositeType::getExtensions( SPIRVType::ExtensionArrayRefVector &extensions, Optional storage) { diff --git a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/cooperative-matrix.mlir @@ -91,4 +91,12 @@ %r = spv.FDiv %a, %b : !spv.coopmatrix<8x16xf32, Subgroup> spv.Return } + + // CHECK-LABEL: @cooperative_matrix_access_chain + spv.func @cooperative_matrix_access_chain(%a : !spv.ptr, Function>) -> !spv.ptr "None" { + %0 = spv.constant 0: i32 + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function> + %1 = spv.AccessChain %a[%0] : !spv.ptr, Function> + spv.ReturnValue %1 : !spv.ptr + } } diff --git a/mlir/test/Dialect/SPIRV/composite-ops.mlir b/mlir/test/Dialect/SPIRV/composite-ops.mlir --- a/mlir/test/Dialect/SPIRV/composite-ops.mlir +++ b/mlir/test/Dialect/SPIRV/composite-ops.mlir @@ -20,6 +20,14 @@ // ----- +func @composite_construct_coopmatrix(%arg0 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> { + // CHECK: spv.CompositeConstruct {{%.*}} : !spv.coopmatrix<8x16xf32, Subgroup> + %0 = spv.CompositeConstruct %arg0 : !spv.coopmatrix<8x16xf32, Subgroup> + return %0: !spv.coopmatrix<8x16xf32, Subgroup> +} + +// ----- + func @composite_construct_empty_struct() -> !spv.struct<> { // CHECK: spv.CompositeConstruct : !spv.struct<> %0 = spv.CompositeConstruct : !spv.struct<> @@ -52,6 +60,14 @@ // ----- +func @composite_construct_coopmatrix(%arg0 : f32, %arg1 : f32) -> !spv.coopmatrix<8x16xf32, Subgroup> { + // expected-error @+1 {{has incorrect number of operands: expected 1, but provided 2}} + %0 = spv.CompositeConstruct %arg0, %arg1 : !spv.coopmatrix<8x16xf32, Subgroup> + return %0: !spv.coopmatrix<8x16xf32, Subgroup> +} + +// ----- + //===----------------------------------------------------------------------===// // spv.CompositeExtractOp //===----------------------------------------------------------------------===// @@ -80,6 +96,14 @@ // ----- +func @composite_extract_coopmatrix(%arg0 : !spv.coopmatrix<8x16xf32, Subgroup>) -> f32 { + // CHECK: {{%.*}} = spv.CompositeExtract {{%.*}}[2 : i32] : !spv.coopmatrix<8x16xf32, Subgroup> + %0 = spv.CompositeExtract %arg0[2 : i32] : !spv.coopmatrix<8x16xf32, Subgroup> + return %0 : f32 +} + +// ----- + func @composite_extract_no_ssa_operand() -> () { // expected-error @+1 {{expected SSA operand}} %0 = spv.CompositeExtract [4 : i32, 1 : i32] : !spv.array<4x!spv.array<4xf32>> @@ -200,6 +224,14 @@ // ----- +func @composite_insert_coopmatrix(%arg0: !spv.coopmatrix<8x16xi32, Subgroup>, %arg1: i32) -> !spv.coopmatrix<8x16xi32, Subgroup> { + // CHECK: {{%.*}} = spv.CompositeInsert {{%.*}}, {{%.*}}[5 : i32] : i32 into !spv.coopmatrix<8x16xi32, Subgroup> + %0 = spv.CompositeInsert %arg1, %arg0[5 : i32] : i32 into !spv.coopmatrix<8x16xi32, Subgroup> + return %0: !spv.coopmatrix<8x16xi32, Subgroup> +} + +// ----- + func @composite_insert_no_indices(%arg0: !spv.array<4xf32>, %arg1: f32) -> !spv.array<4xf32> { // expected-error @+1 {{expected at least one index}} %0 = spv.CompositeInsert %arg1, %arg0[] : f32 into !spv.array<4xf32> diff --git a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir --- a/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir +++ b/mlir/test/Dialect/SPIRV/cooperative-matrix.mlir @@ -94,6 +94,16 @@ // ----- +// CHECK-LABEL: @cooperative_matrix_access_chain +spv.func @cooperative_matrix_access_chain(%a : !spv.ptr, Function>) -> !spv.ptr "None" { + %0 = spv.constant 0: i32 + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr, Function> + %1 = spv.AccessChain %a[%0] : !spv.ptr, Function> + spv.ReturnValue %1 : !spv.ptr +} + +// ----- + spv.func @cooperative_matrix_muladd(%a : !spv.coopmatrix<16x16xi32, Subgroup>, %b : !spv.coopmatrix<16x8xi32, Subgroup>, %c : !spv.coopmatrix<8x8xi32, Subgroup>) "None" { // expected-error @+1 {{'spv.CooperativeMatrixMulAddNV' op matrix size must match}} %r = spv.CooperativeMatrixMulAddNV %a, %b, %c : !spv.coopmatrix<16x16xi32, Subgroup>, !spv.coopmatrix<16x8xi32, Subgroup> -> !spv.coopmatrix<8x8xi32, Subgroup>