Patch note (SPIRVBase.td): adds SPV_OC_OpMatrixTimesMatrix (SPIR-V opcode 146) as an I32EnumAttrCase and inserts it into the opcode case list after SPV_OC_OpMatrixTimesScalar; every other case in the second hunk is unchanged and only re-wrapped. This section is generated from the SPIR-V spec (DO NOT MODIFY), so regenerate it rather than hand-editing the hunk context.
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3167,6 +3167,7 @@ def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>; def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>; +def SPV_OC_OpMatrixTimesMatrix : I32EnumAttrCase<"OpMatrixTimesMatrix", 146>; def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>; def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>; @@ -3274,38 +3275,38 @@ SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, - SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, - SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, - SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, - SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, - SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, - SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, - SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, - SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, - SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, - SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, - SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, - SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, SPV_OC_OpBitwiseXor, - SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, SPV_OC_OpBitFieldInsert, - SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, SPV_OC_OpBitReverse, - SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, 
SPV_OC_OpMemoryBarrier, - SPV_OC_OpAtomicCompareExchangeWeak, SPV_OC_OpAtomicIIncrement, - SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, SPV_OC_OpAtomicISub, - SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, SPV_OC_OpAtomicSMax, - SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, - SPV_OC_OpPhi, SPV_OC_OpLoopMerge, SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, - SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, - SPV_OC_OpReturnValue, SPV_OC_OpUnreachable, SPV_OC_OpNoLine, - SPV_OC_OpModuleProcessed, SPV_OC_OpGroupNonUniformElect, - SPV_OC_OpGroupNonUniformBallot, SPV_OC_OpGroupNonUniformIAdd, - SPV_OC_OpGroupNonUniformFAdd, SPV_OC_OpGroupNonUniformIMul, - SPV_OC_OpGroupNonUniformFMul, SPV_OC_OpGroupNonUniformSMin, - SPV_OC_OpGroupNonUniformUMin, SPV_OC_OpGroupNonUniformFMin, - SPV_OC_OpGroupNonUniformSMax, SPV_OC_OpGroupNonUniformUMax, - SPV_OC_OpGroupNonUniformFMax, SPV_OC_OpSubgroupBallotKHR, - SPV_OC_OpTypeCooperativeMatrixNV, SPV_OC_OpCooperativeMatrixLoadNV, - SPV_OC_OpCooperativeMatrixStoreNV, SPV_OC_OpCooperativeMatrixMulAddNV, - SPV_OC_OpCooperativeMatrixLengthNV + SPV_OC_OpMatrixTimesScalar, SPV_OC_OpMatrixTimesMatrix, SPV_OC_OpLogicalEqual, + SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, + SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, SPV_OC_OpINotEqual, + SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, + SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, + SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, + SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, + SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, + SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, + SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, + SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical, + SPV_OC_OpShiftRightArithmetic, SPV_OC_OpShiftLeftLogical, SPV_OC_OpBitwiseOr, + 
SPV_OC_OpBitwiseXor, SPV_OC_OpBitwiseAnd, SPV_OC_OpNot, + SPV_OC_OpBitFieldInsert, SPV_OC_OpBitFieldSExtract, SPV_OC_OpBitFieldUExtract, + SPV_OC_OpBitReverse, SPV_OC_OpBitCount, SPV_OC_OpControlBarrier, + SPV_OC_OpMemoryBarrier, SPV_OC_OpAtomicCompareExchangeWeak, + SPV_OC_OpAtomicIIncrement, SPV_OC_OpAtomicIDecrement, SPV_OC_OpAtomicIAdd, + SPV_OC_OpAtomicISub, SPV_OC_OpAtomicSMin, SPV_OC_OpAtomicUMin, + SPV_OC_OpAtomicSMax, SPV_OC_OpAtomicUMax, SPV_OC_OpAtomicAnd, + SPV_OC_OpAtomicOr, SPV_OC_OpAtomicXor, SPV_OC_OpPhi, SPV_OC_OpLoopMerge, + SPV_OC_OpSelectionMerge, SPV_OC_OpLabel, SPV_OC_OpBranch, + SPV_OC_OpBranchConditional, SPV_OC_OpReturn, SPV_OC_OpReturnValue, + SPV_OC_OpUnreachable, SPV_OC_OpNoLine, SPV_OC_OpModuleProcessed, + SPV_OC_OpGroupNonUniformElect, SPV_OC_OpGroupNonUniformBallot, + SPV_OC_OpGroupNonUniformIAdd, SPV_OC_OpGroupNonUniformFAdd, + SPV_OC_OpGroupNonUniformIMul, SPV_OC_OpGroupNonUniformFMul, + SPV_OC_OpGroupNonUniformSMin, SPV_OC_OpGroupNonUniformUMin, + SPV_OC_OpGroupNonUniformFMin, SPV_OC_OpGroupNonUniformSMax, + SPV_OC_OpGroupNonUniformUMax, SPV_OC_OpGroupNonUniformFMax, + SPV_OC_OpSubgroupBallotKHR, SPV_OC_OpTypeCooperativeMatrixNV, + SPV_OC_OpCooperativeMatrixLoadNV, SPV_OC_OpCooperativeMatrixStoreNV, + SPV_OC_OpCooperativeMatrixMulAddNV, SPV_OC_OpCooperativeMatrixLengthNV ]>; // End opcode section. Generated from SPIR-V spec; DO NOT MODIFY! 
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
--- a/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td
@@ -12,10 +12,65 @@
 
 #ifndef SPIRV_MATRIX_OPS
 #define SPIRV_MATRIX_OPS
+include "mlir/Interfaces/SideEffectInterfaces.td"
 
 // -----
 
-def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> {
+def SPV_MatrixTimesMatrixOp : SPV_Op<"MatrixTimesMatrix", [NoSideEffect]> {
+  let summary = "Linear-algebraic multiply of LeftMatrix X RightMatrix.";
+
+  let description = [{
+    Result Type must be an OpTypeMatrix whose Column Type is a vector of
+    floating-point type.
+
+    LeftMatrix must be a matrix whose Column Type is the same as the Column
+    Type in Result Type.
+
+    RightMatrix must be a matrix with the same Component Type as the
+    Component Type in Result Type. Its number of columns must equal the
+    number of columns in Result Type. Its columns must have the same number
+    of components as the number of columns in LeftMatrix.
+
+    <!-- End of AutoGen section -->
+
+    ```
+    matrix-times-matrix-op ::= ssa-id `=` `spv.MatrixTimesMatrix` ssa-use,
+    ssa-use `:` matrix-type `,` matrix-type `->` matrix-type
+    ```
+
+    #### Example:
+
+    ```mlir
+    %0 = spv.MatrixTimesMatrix %matrix_1, %matrix_2 :
+        !spv.matrix<4 x vector<3xf32>>, !spv.matrix<3 x vector<4xf32>> ->
+        !spv.matrix<3 x vector<3xf32>>
+    ```
+  }];
+
+  let availability = [
+    MinVersion<SPV_V_1_0>,
+    MaxVersion<SPV_V_1_5>,
+    Extension<[]>,
+    Capability<[SPV_C_Matrix]>
+  ];
+
+  let arguments = (ins
+    SPV_AnyMatrix:$leftmatrix,
+    SPV_AnyMatrix:$rightmatrix
+  );
+
+  let results = (outs
+    SPV_AnyMatrix:$result
+  );
+  let assemblyFormat = [{
+    operands attr-dict `:` type($leftmatrix) `,` type($rightmatrix) `->` type($result)
+  }];
+  let verifier = [{ return verifyMatrixTimesMatrix(*this); }];
+}
+
+// -----
+
+def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", [NoSideEffect]> {
   let summary = "Scale a floating-point matrix.";
 
   let description = [{
@@ -79,7 +134,7 @@
 
 // -----
 
-def SPV_TransposeOp : SPV_Op<"Transpose", []> {
+def SPV_TransposeOp : SPV_Op<"Transpose", [NoSideEffect]> {
   let summary = "Transpose a matrix.";
 
   let description = [{
diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
--- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp
@@ -2924,6 +2924,54 @@
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// spv.MatrixTimesMatrix
+//===----------------------------------------------------------------------===//
+
+/// Verifies that the types of `op` describe a valid matrix product: writing
+/// LeftMatrix as R rows x K columns, RightMatrix must be K rows x C columns
+/// and Result must be R rows x C columns, all with one component type.
+static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) {
+  auto leftMatrix = op.leftmatrix().getType().cast<spirv::MatrixType>();
+  auto rightMatrix = op.rightmatrix().getType().cast<spirv::MatrixType>();
+  auto resultMatrix = op.result().getType().cast<spirv::MatrixType>();
+
+  // For spirv::MatrixType, getNumElements() is the column count and
+  // getElementType() is the column vector type, whose length is the row
+  // count. cast<> is used so that an unexpected non-vector column type trips
+  // an assertion instead of letting every check below be skipped.
+  auto leftColumn = leftMatrix.getElementType().cast<VectorType>();
+  auto rightColumn = rightMatrix.getElementType().cast<VectorType>();
+  auto resultColumn = resultMatrix.getElementType().cast<VectorType>();
+
+  // LeftMatrix's column count must equal RightMatrix's row count.
+  if (leftMatrix.getNumElements() != rightColumn.getNumElements())
+    return op.emitError("left matrix columns' count must be equal to "
+                        "the right matrix rows' count");
+
+  // RightMatrix and Result must have the same column count.
+  if (rightMatrix.getNumElements() != resultMatrix.getNumElements())
+    return op.emitError(
+        "right and result matrices must have equal columns' count");
+
+  // RightMatrix and Result must share the same component type.
+  if (rightColumn.getElementType() != resultColumn.getElementType())
+    return op.emitError("right and result matrices' component "
+                        "type must be the same");
+
+  // LeftMatrix and Result must share the same component type.
+  if (leftColumn.getElementType() != resultColumn.getElementType())
+    return op.emitError("left and result matrices' component type"
+                        " must be the same");
+
+  // LeftMatrix and Result must have the same row count.
+  if (leftColumn.getNumElements() != resultColumn.getNumElements())
+    return op.emitError(
+        "left and result matrices must have equal rows' count");
+
+  return success();
+}
+
 namespace mlir {
 namespace spirv {
 
diff --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
--- a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
+++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir
@@ -29,6 +29,20 @@
     %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>>
     spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>>
   }
+
+  // CHECK-LABEL: @matrix_times_matrix_1
+  spv.func @matrix_times_matrix_1(%arg0: !spv.matrix<3 x vector<3xf32>>, %arg1: !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None"{
+    // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+    %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+    spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
+  }
+
+  // CHECK-LABEL: @matrix_times_matrix_2
+  spv.func @matrix_times_matrix_2(%arg0: !spv.matrix<3 x vector<2xf32>>, %arg1: !spv.matrix<2 x vector<3xf32>>) -> !spv.matrix<2 x vector<2xf32>> "None"{
+    // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
+    %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
+    spv.ReturnValue %result : !spv.matrix<2 x vector<2xf32>>
+  }
 }
 
 // -----
diff --git a/mlir/test/Dialect/SPIRV/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/matrix-ops.mlir
--- a/mlir/test/Dialect/SPIRV/matrix-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/matrix-ops.mlir
@@ -21,6 +21,20 @@
     %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
     spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
   }
+
+  // CHECK-LABEL: @matrix_times_matrix_1
+  spv.func @matrix_times_matrix_1(%arg0: !spv.matrix<3 x vector<3xf32>>, %arg1: !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None"{
+    // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+    %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+    spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>>
+  }
+
+  // CHECK-LABEL: @matrix_times_matrix_2
+  spv.func @matrix_times_matrix_2(%arg0: !spv.matrix<3 x vector<2xf32>>, %arg1: !spv.matrix<2 x vector<3xf32>>) -> !spv.matrix<2 x vector<2xf32>> "None"{
+    // CHECK: {{%.*}} = spv.MatrixTimesMatrix {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
+    %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<2xf32>>
+    spv.ReturnValue %result : !spv.matrix<2 x vector<2xf32>>
+  }
 }
 
 // -----
@@ -74,3 +88,38 @@
   %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<3xf16>>
   spv.Return
 }
+
+// -----
+
+func @matrix_times_matrix_invalid_input_shape_1(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<3xf32>>){
+  // expected-error @+1 {{right and result matrices must have equal columns' count}}
+  %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<3 x vector<2xf32>>
+}
+
+// -----
+
+func @matrix_times_matrix_invalid_input_shape_2(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<3xf32>>){
+  // expected-error @+1 {{left and result matrices must have equal rows' count}}
+  %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<3xf32>> -> !spv.matrix<2 x vector<3xf32>>
+}
+
+// -----
+
+func @matrix_times_matrix_inputs_shape_mismatch(%arg0 : !spv.matrix<3 x vector<2xf32>>, %arg1 : !spv.matrix<2 x vector<2xf32>>){
+  // expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}}
+  %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<2xf32>>, !spv.matrix<2 x vector<2xf32>> -> !spv.matrix<2 x vector<2xf32>>
+}
+
+// -----
+
+func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : !spv.matrix<3 x vector<3xf32>>){
+  // expected-error @+1 {{right and result matrices' component type must be the same}}
+  %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf64>>
+}
+
+// -----
+
+func @matrix_times_matrix_component_type_mismatch_2(%arg0 : !spv.matrix<3 x vector<3xf64>>, %arg1 : !spv.matrix<3 x vector<3xf32>>){
+  // expected-error @+1 {{left and result matrices' component type must be the same}}
+  %result = spv.MatrixTimesMatrix %arg0, %arg1 : !spv.matrix<3 x vector<3xf64>>, !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>>
+}