diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4119,6 +4119,9 @@ AnyTypeOf<[type, VectorOfLengthAndType<[2, 3, 4, 8, 16], [type]>, SPIRV_CoopMatrixOfType<[type]>]>; +class SPIRV_MatrixOrCoopMatrixOf : + AnyTypeOf<[SPIRV_AnyMatrix, SPIRV_CoopMatrixOfType<[type]>]>; + def SPIRV_ScalarOrVector : AnyTypeOf<[SPIRV_Scalar, SPIRV_Vector]>; def SPIRV_ScalarOrVectorOrPtr : AnyTypeOf<[SPIRV_ScalarOrVector, SPIRV_AnyPtr]>; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td @@ -70,7 +70,8 @@ // ----- -def SPIRV_MatrixTimesScalarOp : SPIRV_Op<"MatrixTimesScalar", [Pure]> { +def SPIRV_MatrixTimesScalarOp : SPIRV_Op< + "MatrixTimesScalar", [Pure, AllTypesMatch<["matrix", "result"]>]> { let summary = "Scale a floating-point matrix."; let description = [{ @@ -108,18 +109,16 @@ ]; let arguments = (ins - SPIRV_AnyMatrix:$matrix, + SPIRV_MatrixOrCoopMatrixOf:$matrix, SPIRV_Float:$scalar ); let results = (outs - SPIRV_AnyMatrix:$result + SPIRV_MatrixOrCoopMatrixOf:$result ); - // TODO: we need just one matrix type given that the input and result are the - // same and the scalar's type can be deduced from it. let assemblyFormat = [{ - operands attr-dict `:` type($matrix) `,` type($scalar) `->` type($result) + operands attr-dict `:` type($matrix) `,` type($scalar) }]; let availability = [ diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -4128,35 +4128,20 @@ //===----------------------------------------------------------------------===// LogicalResult spirv::MatrixTimesScalarOp::verify() { - // We already checked that result and matrix are both of matrix type in the - // auto-generated verify method. - - auto inputMatrix = getMatrix().getType().cast(); - auto resultMatrix = getResult().getType().cast(); + if (auto inputCoopmat = + getMatrix().getType().dyn_cast()) { + if (inputCoopmat.getElementType() != getScalar().getType()) + return emitError("input matrix components' type and scaling value must " + "have the same type"); + return success(); + } // Check that the scalar type is the same as the matrix element type. + auto inputMatrix = getMatrix().getType().cast(); if (getScalar().getType() != inputMatrix.getElementType()) return emitError("input matrix components' type and scaling value must " "have the same type"); - // Note that the next three checks could be done using the AllTypesMatch - // trait in the Op definition file but it generates a vague error message. - - // Check that the input and result matrices have the same columns' count - if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns()) - return emitError("input and result matrices must have the same " - "number of columns"); - - // Check that the input and result matrices' have the same rows count - if (inputMatrix.getNumRows() != resultMatrix.getNumRows()) - return emitError("input and result matrices' columns must have " - "the same size"); - - // Check that the input and result matrices' have the same component type - if (inputMatrix.getElementType() != resultMatrix.getElementType()) - return emitError("input and result matrices' columns must have " - "the same component type"); - return success(); } diff --git a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir --- a/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/matrix-ops.mlir @@ -1,13 +1,20 @@ // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s spirv.module Logical GLSL450 requires #spirv.vce { - // CHECK-LABEL: @matrix_times_scalar - spirv.func @matrix_times_scalar(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" { - // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>> - %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>> + // CHECK-LABEL: @matrix_times_scalar_1 + spirv.func @matrix_times_scalar_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" { + // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32 + %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>> } + // CHECK-LABEL: @matrix_times_scalar_2 + spirv.func @matrix_times_scalar_2(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" { + // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 + %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 + spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup> + } + // CHECK-LABEL: @matrix_transpose_1 spirv.func @matrix_transpose_1(%arg0 : !spirv.matrix<3 x vector<2xf32>>) -> !spirv.matrix<2 x vector<3xf32>> "None" { // CHECK: {{%.*}} = spirv.Transpose {{%.*}} : !spirv.matrix<3 x vector<2xf32>> -> !spirv.matrix<2 x vector<3xf32>> @@ -39,54 +46,42 @@ // ----- -func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f16) -> () { +func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f16) { // expected-error @+1 {{input matrix components' type and scaling value must have the same type}} - %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f16 -> !spirv.matrix<3 x vector<3xf32>> + %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f16 + return } // ----- -func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f64) -> () { +func.func @input_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f64) { // expected-error @+1 {{input matrix components' type and scaling value must have the same type}} - %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f64 -> !spirv.matrix<3 x vector<3xf32>> -} - -// ----- - -func.func @input_output_component_type_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () { - // expected-error @+1 {{input and result matrices' columns must have the same component type}} - %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf64>> -} - -// ----- - -func.func @input_output_size_mismatch(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () { - // expected-error @+1 {{input and result matrices must have the same number of columns}} - %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<4 x vector<3xf32>> + %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f64 + return } // ----- -func.func @transpose_op_shape_mismatch_1(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () { +func.func @transpose_op_shape_mismatch_1(%arg0 : !spirv.matrix<3 x vector<4xf32>>) { // expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}} %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<3 x vector<3xf32>> - spirv.Return + return } // ----- -func.func @transpose_op_shape_mismatch_2(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () { +func.func @transpose_op_shape_mismatch_2(%arg0 : !spirv.matrix<3 x vector<4xf32>>) { // expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}} %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<2 x vector<4xf32>> - spirv.Return + return } // ----- -func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>) -> () { +func.func @transpose_op_type_mismatch(%arg0 : !spirv.matrix<3 x vector<4xf32>>) { // expected-error @+1 {{input and output matrices must have the same component type}} %result = spirv.Transpose %arg0 : !spirv.matrix<3 x vector<4xf32>> -> !spirv.matrix<4 x vector<3xf16>> - spirv.Return + return } // ----- @@ -94,6 +89,7 @@ func.func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<3xf32>>){ // expected-error @+1 {{right and result matrices must have equal columns' count}} %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<3 x vector<2xf32>> + return } // ----- @@ -101,6 +97,7 @@ func.func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<3xf32>>){ // expected-error @+1 {{left and result matrices must have equal rows' count}} %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<3xf32>> -> !spirv.matrix<2 x vector<3xf32>> + return } // ----- @@ -108,6 +105,7 @@ func.func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spirv.matrix<3 x vector<2xf32>>, %arg1 : !spirv.matrix<2 x vector<2xf32>>){ // expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}} %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<2xf32>>, !spirv.matrix<2 x vector<2xf32>> -> !spirv.matrix<2 x vector<2xf32>> + return } // ----- @@ -115,6 +113,7 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){ // expected-error @+1 {{right and result matrices' component type must be the same}} %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf64>> + return } @@ -123,4 +122,5 @@ func.func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spirv.matrix<3 x vector<3xf64>>, %arg1 : !spirv.matrix<3x vector<3xf32>>){ // expected-error @+1 {{left and result matrices' component type must be the same}} %result = spirv.MatrixTimesMatrix %arg0, %arg1 : !spirv.matrix<3 x vector<3xf64>>, !spirv.matrix<3 x vector<3xf32>> -> !spirv.matrix<3 x vector<3xf32>> + return } diff --git a/mlir/test/Target/SPIRV/matrix.mlir b/mlir/test/Target/SPIRV/matrix.mlir --- a/mlir/test/Target/SPIRV/matrix.mlir +++ b/mlir/test/Target/SPIRV/matrix.mlir @@ -10,17 +10,23 @@ // CHECK-LABEL: @matrix_times_scalar_1 spirv.func @matrix_times_scalar_1(%arg0 : !spirv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spirv.matrix<3 x vector<3xf32>> "None" { - // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>> - %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>> + // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32 + %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf32>>, f32 spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf32>> } // CHECK-LABEL: @matrix_times_scalar_2 spirv.func @matrix_times_scalar_2(%arg0 : !spirv.matrix<3 x vector<3xf16>>, %arg1 : f16) -> !spirv.matrix<3 x vector<3xf16>> "None" { - // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf16>>, f16 -> !spirv.matrix<3 x vector<3xf16>> - %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf16>>, f16 -> !spirv.matrix<3 x vector<3xf16>> + // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf16>>, f16 + %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.matrix<3 x vector<3xf16>>, f16 spirv.ReturnValue %result : !spirv.matrix<3 x vector<3xf16>> + } + // CHECK-LABEL: @matrix_times_scalar_3 + spirv.func @matrix_times_scalar_3(%arg0 : !spirv.coopmatrix<16x16xf16, Subgroup>, %arg1 : f16) -> !spirv.coopmatrix<16x16xf16, Subgroup> "None" { + // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 + %result = spirv.MatrixTimesScalar %arg0, %arg1 : !spirv.coopmatrix<16x16xf16, Subgroup>, f16 + spirv.ReturnValue %result : !spirv.coopmatrix<16x16xf16, Subgroup> } // CHECK-LABEL: @matrix_transpose_1