diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3166,6 +3166,7 @@ def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>; def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>; +def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>; def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>; def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>; @@ -3273,38 +3274,38 @@ SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, - SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, - SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, - SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, - SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, - SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, - SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, - SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, - SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, - SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, - SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, - SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, - SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, - SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert, - SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse, - SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, 
SPV_OC_OpMemoryBarrier, - SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement, - SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub, - SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax, - SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, - SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, - SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, - SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine, - SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect, - SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd, - SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul, - SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin, - SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin, - SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax, - SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR, - SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV, - SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV, - SPV_OC_OpCooperativeMatrixLengthNV + SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual, + SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, + SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, + SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, + SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, + SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, + SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, + SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, + SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, + SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, + SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical, + SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, + 
SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, + SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, + SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, + SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicCompareExchangeWeak, + SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, + SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, + SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, + SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge, + SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, + SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue, + SPV_OC_OpUnreachable, SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, + SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot, + SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd, + SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul, + SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin, + SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax, + SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax, + SPV_OC_OpSubgroupBallotKHR, SPV_OC_OpTypeCooperativeMatrixNV, + SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV, + SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! 
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td @@ -12,10 +12,65 @@ #ifndef SPIRV_MATRIX_OPS #define SPIRV_MATRIX_OPS +include "mlir/Interfaces/SideEffectInterfaces.td" // ----- -def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> { +def SPV_MatrixTimesMatrixOp : SPV_Op<"MatrixTimesMatrix", [NoSideEffect]> { + let summary = "Linear-algebraic multiply of LeftMatrix X RightMatrix."; + + let description = [{ + Result Type must be an OpTypeMatrix whose Column Type is a vector of + floating-point type. + + LeftMatrix must be a matrix whose Column Type is the same as the Column + Type in Result Type. + + RightMatrix must be a matrix with the same Component Type as the + Component Type in Result Type. Its number of columns must equal the + number of columns in Result Type. Its columns must have the same number + of components as the number of columns in LeftMatrix. 
+ + + ``` + matrix-times-matrix-op ::= ssa-id `=` `spv.MatrixTimesMatrix` ssa-use, + ssa-use `:` matrix-type `,` matrix-type `->` matrix-type + ``` + + #### Example: + + ```mlir + %0 = spv.MatrixTimesMatrix %matrix_1, %matrix_2 : + !spv.matrix<4 x vector<3xf32>>, !spv.matrix<3 x vector<4xf32>> -> + !spv.matrix<3 x vector<3xf32>> + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Matrix]> + ]; + + let arguments = (ins + SPV_AnyMatrix:$leftmatrix, + SPV_AnyMatrix:$rightmatrix + ); + + let results = (outs + SPV_AnyMatrix:$result + ); + let assemblyFormat = [{ + operands attr-dict `:` type($leftmatrix) `,` type($rightmatrix) `->` type($result) + }]; + let verifier = [{ return verifyMatrixTimesMatrix(*this); }]; +} + +// ----- + +def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", [NoSideEffect]> { let summary = "Scale a floating-point matrix."; let description = [{ @@ -79,7 +134,7 @@ // ----- -def SPV_TransposeOp : SPV_Op<"Transpose", []> { +def SPV_TransposeOp : SPV_Op<"Transpose", [NoSideEffect]> { let summary = "Transpose a matrix."; let description = [{ diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVTypes.h @@ -410,13 +410,23 @@ Type columnType, uint32_t columnCount); - /// Returns true if the matrix elements are vectors of float elements + /// Returns true if the matrix elements are vectors of float elements. static bool isValidColumnType(Type columnType); - Type getElementType() const; + Type getColumnType() const; + + /// Returns the number of rows. + unsigned getNumRows() const; + + /// Returns the number of columns. + unsigned getNumColumns() const; + /// Returns total number of elements (rows*columns). unsigned getNumElements() const; + /// Returns the elements' type (i.e., single element type). 
+ Type getElementType() const; + void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage = llvm::None); void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities, diff --git a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVDialect.cpp @@ -723,7 +723,7 @@ } static void print(MatrixType type, DialectAsmPrinter &os) { - os << "matrix<" << type.getNumElements() << " x " << type.getElementType(); + os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType(); os << ">"; } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -2779,37 +2779,30 @@ // auto-generated verify method. auto inputMatrix = op.matrix().getType().cast(); - // Check that the scalar type is the same as the matrix components type. - if (auto inputMatrixColumns = - inputMatrix.getElementType().dyn_cast()) { - if (op.scalar().getType() != inputMatrixColumns.getElementType()) - return op.emitError("input matrix components' type and scaling " - "value must have the same type"); - - // Note that the next three checks could be done using the AllTypesMatch - // trait in the Op definition file but it generates a vague error message. 
- - // Check that the input and result matrices have the same size - auto resultMatrix = op.result().getType().cast(); - if (inputMatrix.getNumElements() != resultMatrix.getNumElements()) - return op.emitError("input and result matrices must have " - "the same number of columns"); - - if (auto resultMatrixColumns = - resultMatrix.getElementType().dyn_cast()) { - // Check that the input and result matrices' columns have the same type - if (inputMatrixColumns.getElementType() != - resultMatrixColumns.getElementType()) - return op.emitError("input and result matrices' columns must " - "have the same component type"); - - // Check that the input and result matrices' columns have the same size - if (inputMatrixColumns.getNumElements() != - resultMatrixColumns.getNumElements()) - return op.emitError("input and result matrices' columns must " - "have the same size"); - } - } + auto resultMatrix = op.result().getType().cast(); + + // Check that the scalar type is the same as the matrix element type. + if (op.scalar().getType() != inputMatrix.getElementType()) + return op.emitError("input matrix components' type and scaling value must " + "have the same type"); + + // Note that the next three checks could be done using the AllTypesMatch + // trait in the Op definition file but it generates a vague error message. 
+ + // Check that the input and result matrices have the same number of columns + if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns()) + return op.emitError("input and result matrices must have the same " + "number of columns"); + + // Check that the input and result matrices have the same number of rows + if (inputMatrix.getNumRows() != resultMatrix.getNumRows()) + return op.emitError("input and result matrices' columns must have " + "the same size"); + + // Check that the input and result matrices have the same component type + if (inputMatrix.getElementType() != resultMatrix.getElementType()) + return op.emitError("input and result matrices' columns must have " + "the same component type"); return success(); } @@ -2902,24 +2895,56 @@ auto resultMatrix = op.result().getType().cast(); // Verify that the input and output matrices have correct shapes. - if (auto inputMatrixColumns = - inputMatrix.getElementType().dyn_cast()) { - if (inputMatrixColumns.getNumElements() != resultMatrix.getNumElements()) - return op.emitError("input matrix rows count must be equal to " - "output matrix columns count"); - if (auto resultMatrixColumns = - resultMatrix.getElementType().dyn_cast()) { - if (resultMatrixColumns.getNumElements() != inputMatrix.getNumElements()) - return op.emitError("input matrix columns count must be equal " - "to output matrix rows count"); - - // Verify that the input and output matrices have the same component type - if (inputMatrixColumns.getElementType() != - resultMatrixColumns.getElementType()) - return op.emitError("input and output matrices must have the " - "same component type"); - } - } + if (inputMatrix.getNumRows() != resultMatrix.getNumColumns()) + return op.emitError("input matrix rows count must be equal to " + "output matrix columns count"); + + if (inputMatrix.getNumColumns() != resultMatrix.getNumRows()) + return op.emitError("input matrix columns count must be equal to " + "output matrix rows count"); + + // Verify that the input 
and output matrices have the same component type + if (inputMatrix.getElementType() != resultMatrix.getElementType()) + return op.emitError("input and output matrices must have the same " + "component type"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// spv.MatrixTimesMatrix +//===----------------------------------------------------------------------===// + +static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) { + auto leftMatrix = op.leftmatrix().getType().cast(); + auto rightMatrix = op.rightmatrix().getType().cast(); + auto resultMatrix = op.result().getType().cast(); + + // left matrix columns' count and right matrix rows' count must be equal + if (leftMatrix.getNumColumns() != rightMatrix.getNumRows()) + return op.emitError("left matrix columns' count must be equal to " + "the right matrix rows' count"); + + // right and result matrices columns' count must be the same + if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns()) + return op.emitError( + "right and result matrices must have equal columns' count"); + + // right and result matrices component type must be the same + if (rightMatrix.getElementType() != resultMatrix.getElementType()) + return op.emitError("right and result matrices' component type must" + " be the same"); + + // left and result matrices component type must be the same + if (leftMatrix.getElementType() != resultMatrix.getElementType()) + return op.emitError("left and result matrices' component type" + " must be the same"); + + // left and result matrices rows count must be the same + if (leftMatrix.getNumRows() != resultMatrix.getNumRows()) + return op.emitError("left and result matrices must have equal rows'" + " count"); + return success(); } diff --git a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVTypes.cpp @@ -182,7 +182,7 @@ 
case spirv::TypeKind::CooperativeMatrix: return cast().getElementType(); case spirv::TypeKind::Matrix: - return cast().getElementType(); + return cast().getColumnType(); case spirv::TypeKind::RuntimeArray: return cast().getElementType(); case spirv::TypeKind::Struct: @@ -202,7 +202,7 @@ llvm_unreachable( "invalid to query number of elements of spirv::CooperativeMatrix type"); case spirv::TypeKind::Matrix: - return cast().getNumElements(); + return cast().getNumColumns(); case spirv::TypeKind::RuntimeArray: llvm_unreachable( "invalid to query number of elements of spirv::RuntimeArray type"); @@ -1086,13 +1086,25 @@ return false; } -Type MatrixType::getElementType() const { return getImpl()->columnType; } +Type MatrixType::getColumnType() const { return getImpl()->columnType; } -unsigned MatrixType::getNumElements() const { return getImpl()->columnCount; } +Type MatrixType::getElementType() const { + return getImpl()->columnType.cast().getElementType(); +} + +unsigned MatrixType::getNumColumns() const { return getImpl()->columnCount; } + +unsigned MatrixType::getNumRows() const { + return getImpl()->columnType.cast().getShape()[0]; +} + +unsigned MatrixType::getNumElements() const { + return (getImpl()->columnCount) * getNumRows(); +} void MatrixType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions, Optional storage) { - getElementType().cast().getExtensions(extensions, storage); + getColumnType().cast().getExtensions(extensions, storage); } void MatrixType::getCapabilities( @@ -1104,5 +1116,5 @@ capabilities.push_back(ref); } // Add any capabilities associated with the underlying vectors (i.e., columns) - getElementType().cast().getCapabilities(capabilities, storage); + getColumnType().cast().getCapabilities(capabilities, storage); } diff --git a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp --- a/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp +++ 
b/mlir/lib/Dialect/SPIRV/Serialization/Serializer.cpp @@ -1127,12 +1127,12 @@ if (auto matrixType = type.dyn_cast()) { uint32_t elementTypeID = 0; - if (failed(processType(loc, matrixType.getElementType(), elementTypeID))) { + if (failed(processType(loc, matrixType.getColumnType(), elementTypeID))) { return failure(); } typeEnum = spirv::Opcode::OpTypeMatrix; operands.push_back(elementTypeID); - operands.push_back(matrixType.getNumElements()); + operands.push_back(matrixType.getNumColumns()); return success(); } diff --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir @@ -29,6 +29,20 @@ %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>> spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>> } + + // CHECK-LABEL: @matrix_times_matrix_1 + spv.func @matrix_times_matrix_1(%arg0: !spv.matrix<3 x vector<3xf32>>, %arg1: !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None"{ + // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>> + %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>> + spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>> + } + + // CHECK-LABEL: @matrix_times_matrix_2 + spv.func @matrix_times_matrix_2(%arg0: !spv.matrix<3 x vector<2xf32>>, %arg1: !spv.matrix<2 x vector<3xf32>>) -> !spv.matrix<2 x vector<2xf32>> "None"{ + // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>> + %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>> + 
spv.ReturnValue %result : !spv.matrix<2 x vector<2xf32>> + } } // ----- diff --git a/mlir/test/Dialect/SPIRV/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/matrix-ops.mlir --- a/mlir/test/Dialect/SPIRV/matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/matrix-ops.mlir @@ -21,6 +21,20 @@ %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>> spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>> } + + // CHECK-LABEL: @matrix_times_matrix_1 + spv.func @matrix_times_matrix_1(%arg0: !spv.matrix<3 x vector<3xf32>>, %arg1: !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None"{ + // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>> + %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>> + spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>> + } + + // CHECK-LABEL: @matrix_times_matrix_2 + spv.func @matrix_times_matrix_2(%arg0: !spv.matrix<3 x vector<2xf32>>, %arg1: !spv.matrix<2 x vector<3xf32>>) -> !spv.matrix<2 x vector<2xf32>> "None"{ + // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>> + %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>> + spv.ReturnValue %result : !spv.matrix<2 x vector<2xf32>> + } } // ----- @@ -74,3 +88,39 @@ %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<3xf16>> spv.Return } + +// ----- + +func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<3xf32>>){ + // expected-error @+1 {{right and result matrices must have equal columns' count}} + %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x 
vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<3 x vector<2xf32>> +} + +// ----- + +func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<3xf32>>){ + // expected-error @+1 {{left and result matrices must have equal rows' count}} + %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<3xf32>> +} + +// ----- + +func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<2xf32>>){ + // expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}} + %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<2xf32>> -> !spv.matrix<2 x vector<2xf32>> +} + +// ----- + +func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : !spv.matrix<3x vector<3xf32>>){ + // expected-error @+1 {{right and result matrices' component type must be the same}} + %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf64>> +} + + +// ----- + +func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spv.matrix<3 x vector<3xf64>>, %arg1 : !spv.matrix<3x vector<3xf32>>){ + // expected-error @+1 {{left and result matrices' component type must be the same}} + %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf64>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>> +}