# Review notes. This is a free-form patch preamble: patch tools skip any text
# that precedes the first "diff --git" header.
#
# Purpose: add the spv.MatrixTimesScalar op (SPIR-V opcode 143) to the MLIR
# SPIR-V dialect, together with the SPV_AnyMatrix type constraint it needs, a
# custom verifier, and parser/verifier and serialization round-trip tests.
#
# Files touched:
#   - SPIRVBase.td
#       * adds the SPV_IsMatrixType predicate and the SPV_AnyMatrix constraint;
#       * adds SPV_AnyMatrix to SPV_Composite and SPV_Type;
#       * adds SPV_OC_OpMatrixTimesScalar (143) to the opcode list.
#   - SPIRVMatrixOps.td (new)
#       * defines SPV_MatrixTimesScalarOp with operands (matrix: SPV_AnyMatrix,
#         scalar: SPV_Float) and result SPV_AnyMatrix;
#       * declares an explicit assembly format, capability SPV_C_Matrix, and a
#         custom verifier hook.
#   - SPIRVOps.td: includes SPIRVMatrixOps.td.
#   - SPIRVOps.cpp: verifyMatrixTimesScalar emits an error unless
#       (a) the scalar type equals the input matrix column component type;
#       (b) input and result have the same number of columns;
#       (c) input and result columns have the same component type;
#       (d) input and result columns have the same size.
#   - Serialization/matrix.mlir
#       * the access-chain case now uses Function storage and returns the
#         pointer;
#       * adds two MatrixTimesScalar round-trip cases (f32 and f16).
#   - matrix-ops.mlir (new): one positive case, plus negative cases for
#     checks (a) (twice: f16 and f64 scalars), (c) and (b).
#
# NOTE(review): this copy of the patch appears to have lost its original line
#   breaks and every "<...>" span starting with a letter or "!". Examples:
#   template arguments of cast()/dyn_cast(), DialectType<...> and
#   MinVersion<...>/MaxVersion<...> arguments, and the pointee types inside
#   !spv.ptr<...>. A hunk header before "def SPV_Composite" also looks missing.
#   Restore from the source commit before applying; it will not apply as-is.
# NOTE(review): SPIRVMatrixOps.td is added without a trailing newline.
# NOTE(review): the SPIRVBase.td hunk adds a stray blank line after
#   SPV_IsStructType.
# NOTE(review): in verifyMatrixTimesScalar every check, including the scalar
#   check, runs only when the dyn_cast of the input column type succeeds. The
#   result-column checks likewise run only when the result dyn_cast succeeds.
#   Otherwise the function returns success() without reporting anything. If
#   matrix columns are always vectors (confirm against MatrixType), a cast would
#   make this explicit.
# NOTE(review): check (d) ("columns must have the same size") has no negative
#   test. An example would be vector<3xf32> in and vector<4xf32> out.
# NOTE(review): "TODO (Hazem)" uses a personal name; LLVM code typically uses
#   plain "TODO:".
# NOTE(review): in matrix-ops.mlir, CHECK-LABEL "@matrix_times_scalar" precedes
#   a function named @matrix_times_scalar_1. It passes only as a substring
#   match.
#
diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -2994,10 +2994,12 @@ def SPV_IsArrayType : CPred<"$_self.isa<::mlir::spirv::ArrayType>()">; def SPV_IsCooperativeMatrixType : CPred<"$_self.isa<::mlir::spirv::CooperativeMatrixNVType>()">; +def SPV_IsMatrixType : CPred<"$_self.isa<::mlir::spirv::MatrixType>()">; def SPV_IsPtrType : CPred<"$_self.isa<::mlir::spirv::PointerType>()">; def SPV_IsRTArrayType : CPred<"$_self.isa<::mlir::spirv::RuntimeArrayType>()">; def SPV_IsStructType : CPred<"$_self.isa<::mlir::spirv::StructType>()">; + // See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types // for the definition of the following types and type categories. @@ -3018,6 +3020,8 @@ def SPV_AnyCooperativeMatrix : DialectType; +def SPV_AnyMatrix : DialectType; def SPV_AnyRTArray : DialectType; def SPV_AnyStruct : DialectType; def SPV_Composite : AnyTypeOf<[SPV_Vector, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct, - SPV_AnyCooperativeMatrix]>; + SPV_AnyCooperativeMatrix, SPV_AnyMatrix]>; def SPV_Type : AnyTypeOf<[ SPV_Void, SPV_Bool, SPV_Integer, SPV_Float, SPV_Vector, SPV_AnyPtr, SPV_AnyArray, SPV_AnyRTArray, SPV_AnyStruct, - SPV_AnyCooperativeMatrix + SPV_AnyCooperativeMatrix, SPV_AnyMatrix ]>; def SPV_SignlessOrUnsignedInt : SignlessOrUnsignedIntOfWidths<[8, 16, 32, 64]>; @@ -3160,6 +3164,7 @@ def SPV_OC_OpSMod : I32EnumAttrCase<"OpSMod", 139>; def SPV_OC_OpFRem : I32EnumAttrCase<"OpFRem", 140>; def SPV_OC_OpFMod : I32EnumAttrCase<"OpFMod", 141>; +def SPV_OC_OpMatrixTimesScalar : I32EnumAttrCase<"OpMatrixTimesScalar", 143>; def SPV_OC_OpLogicalEqual : I32EnumAttrCase<"OpLogicalEqual", 164>; def SPV_OC_OpLogicalNotEqual : I32EnumAttrCase<"OpLogicalNotEqual", 165>; def SPV_OC_OpLogicalOr : I32EnumAttrCase<"OpLogicalOr", 166>; @@ -3266,14 +3271,14 @@ SPV_OC_OpFNegate, SPV_OC_OpIAdd,
SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, - SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, - SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, - SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, - SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, - SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, - SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, - SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, - SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, + SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, + SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, + SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, + SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, + SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, + SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, + SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td @@ -0,0 +1,75 @@ +//===-- SPIRVMatrixOps.td - MLIR SPIR-V Matrix Ops ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains matrix operations for the SPIR-V dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef SPIRV_MATRIX_OPS +#define SPIRV_MATRIX_OPS + +// ----- + +def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", []> { + let summary = "Scale a floating-point matrix."; + + let description = [{ + Result Type must be an OpTypeMatrix whose Column Type is a vector of + floating-point type. + + The type of Matrix must be the same as Result Type. Each component in + each column in Matrix is multiplied by Scalar. + + Scalar must have the same type as the Component Type in Result Type. + + + + ``` + matrix-times-scalar-op ::= ssa-id `=` `spv.MatrixTimesScalar` ssa-use, + ssa-use `:` matrix-type `,` float-type `->` matrix-type + + ``` + + #### Example: + + ```mlir + + %0 = spv.MatrixTimesScalar %matrix, %scalar : + !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>> + + ``` + }]; + + let arguments = (ins + SPV_AnyMatrix:$matrix, + SPV_Float:$scalar + ); + + let results = (outs + SPV_AnyMatrix:$result + ); + + // TODO (Hazem): we need just one matrix type given that the input and result + // are the same and the scalar's type can be deduced from it.
+ let assemblyFormat = [{ + operands attr-dict `:` type($matrix) `,` type($scalar) `->` type($result) + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Matrix]> + ]; + + let verifier = [{ return verifyMatrixTimesScalar(*this); }]; +} + +// ----- + +#endif // SPIRV_MATRIX_OPS \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVOps.td @@ -32,6 +32,7 @@ include "mlir/Dialect/SPIRV/SPIRVGLSLOps.td" include "mlir/Dialect/SPIRV/SPIRVGroupOps.td" include "mlir/Dialect/SPIRV/SPIRVLogicalOps.td" +include "mlir/Dialect/SPIRV/SPIRVMatrixOps.td" include "mlir/Dialect/SPIRV/SPIRVNonUniformOps.td" include "mlir/Dialect/SPIRV/SPIRVStructureOps.td" include "mlir/Interfaces/SideEffectInterfaces.td" diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -2760,6 +2760,49 @@ return success(); } +//===----------------------------------------------------------------------===// +// spv.MatrixTimesScalar +//===----------------------------------------------------------------------===// + +static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) { + // We already checked that result and matrix are both of matrix type in the + // auto-generated verify method. + + auto inputMatrix = op.matrix().getType().cast(); + // Check that the scalar type is the same as the matrix components type.
+ if (auto inputMatrixColumns = + inputMatrix.getElementType().dyn_cast()) { + if (op.scalar().getType() != inputMatrixColumns.getElementType()) + return op.emitError("input matrix components' type and scaling " + "value must have the same type"); + + // Note that the next three checks could be done using the AllTypesMatch + // trait in the Op definition file but it generates a vague error message. + + // Check that the input and result matrices have the same size + auto resultMatrix = op.result().getType().cast(); + if (inputMatrix.getNumElements() != resultMatrix.getNumElements()) + return op.emitError("input and result matrices must have " + "the same number of columns"); + + if (auto resultMatrixColumns = + resultMatrix.getElementType().dyn_cast()) { + // Check that the input and result matrices' columns have the same type + if (inputMatrixColumns.getElementType() != + resultMatrixColumns.getElementType()) + return op.emitError("input and result matrices' columns must " + "have the same component type"); + + // Check that the input and result matrices' columns have the same size + if (inputMatrixColumns.getNumElements() != + resultMatrixColumns.getNumElements()) + return op.emitError("input and result matrices' columns must " + "have the same size"); + } + } + return success(); +} + namespace mlir { namespace spirv { diff --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir @@ -1,10 +1,25 @@ // RUN: mlir-translate -split-input-file -test-spirv-roundtrip %s | FileCheck %s spv.module Logical GLSL450 requires #spv.vce { - spv.func @matrix_type(%arg0 : !spv.ptr>, StorageBuffer>, %arg1 : i32) "None" { - // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr>, StorageBuffer> - %2 = spv.AccessChain %arg0[%arg1] : !spv.ptr>, StorageBuffer> - spv.Return + // CHECK-LABEL: @matrix_access_chain
+ spv.func @matrix_access_chain(%arg0 : !spv.ptr>, Function>, %arg1 : i32) -> !spv.ptr, Function> "None" { + // CHECK: {{%.*}} = spv.AccessChain {{%.*}}[{{%.*}}] : !spv.ptr>, Function> + %0 = spv.AccessChain %arg0[%arg1] : !spv.ptr>, Function> + spv.ReturnValue %0 : !spv.ptr, Function> + } + + // CHECK-LABEL: @matrix_times_scalar_1 + spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" { + // CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>> + %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>> + spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>> + } + + // CHECK-LABEL: @matrix_times_scalar_2 + spv.func @matrix_times_scalar_2(%arg0 : !spv.matrix<3 x vector<3xf16>>, %arg1 : f16) -> !spv.matrix<3 x vector<3xf16>> "None" { + // CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>> + %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf16>>, f16 -> !spv.matrix<3 x vector<3xf16>> + spv.ReturnValue %result : !spv.matrix<3 x vector<3xf16>> } } diff --git a/mlir/test/Dialect/SPIRV/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/matrix-ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/SPIRV/matrix-ops.mlir @@ -0,0 +1,41 @@ +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s + +spv.module Logical GLSL450 requires #spv.vce { + // CHECK-LABEL: @matrix_times_scalar + spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" { + // CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>> + %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 ->
!spv.matrix<3 x vector<3xf32>> + spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>> + } +} + +// ----- + +func @input_type_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f16) -> () { + // expected-error @+1 {{input matrix components' type and scaling value must have the same type}} + %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f16 -> !spv.matrix<3 x vector<3xf32>> +} + +// ----- + +func @input_type_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f64) -> () { + // expected-error @+1 {{input matrix components' type and scaling value must have the same type}} + %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f64 -> !spv.matrix<3 x vector<3xf32>> +} + +// ----- + +func @input_output_component_type_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () { + // expected-error @+1 {{input and result matrices' columns must have the same component type}} + %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf64>> +} + +// ----- + +func @input_output_size_mismatch(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> () { + // expected-error @+1 {{input and result matrices must have the same number of columns}} + %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<4 x vector<3xf32>> +} + + +