diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -3141,6 +3141,7 @@ def SPV_OC_OpCompositeConstruct : I32EnumAttrCase<"OpCompositeConstruct", 80>; def SPV_OC_OpCompositeExtract : I32EnumAttrCase<"OpCompositeExtract", 81>; def SPV_OC_OpCompositeInsert : I32EnumAttrCase<"OpCompositeInsert", 82>; +def SPV_OC_OpTranspose : I32EnumAttrCase<"OpTranspose", 84>; def SPV_OC_OpConvertFToU : I32EnumAttrCase<"OpConvertFToU", 109>; def SPV_OC_OpConvertFToS : I32EnumAttrCase<"OpConvertFToS", 110>; def SPV_OC_OpConvertSToF : I32EnumAttrCase<"OpConvertSToF", 111>; @@ -3265,20 +3266,21 @@ SPV_OC_OpFunctionParameter, SPV_OC_OpFunctionEnd, SPV_OC_OpFunctionCall, SPV_OC_OpVariable, SPV_OC_OpLoad, SPV_OC_OpStore, SPV_OC_OpAccessChain, SPV_OC_OpDecorate, SPV_OC_OpMemberDecorate, SPV_OC_OpCompositeConstruct, - SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpConvertFToU, - SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, SPV_OC_OpConvertUToF, - SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, SPV_OC_OpBitcast, - SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, SPV_OC_OpISub, SPV_OC_OpFSub, - SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, SPV_OC_OpSDiv, SPV_OC_OpFDiv, - SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, SPV_OC_OpFRem, SPV_OC_OpFMod, - SPV_OC_OpMatrixTimesScalar, SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, - SPV_OC_OpLogicalOr, SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, - SPV_OC_OpIEqual, SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, - SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, - SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, - SPV_OC_OpSLessThanEqual, SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, - SPV_OC_OpFOrdNotEqual, SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, - SPV_OC_OpFUnordLessThan, SPV_OC_OpFOrdGreaterThan, 
SPV_OC_OpFUnordGreaterThan, + SPV_OC_OpCompositeExtract, SPV_OC_OpCompositeInsert, SPV_OC_OpTranspose, + SPV_OC_OpConvertFToU, SPV_OC_OpConvertFToS, SPV_OC_OpConvertSToF, + SPV_OC_OpConvertUToF, SPV_OC_OpUConvert, SPV_OC_OpSConvert, SPV_OC_OpFConvert, + SPV_OC_OpBitcast, SPV_OC_OpFNegate, SPV_OC_OpIAdd, SPV_OC_OpFAdd, + SPV_OC_OpISub, SPV_OC_OpFSub, SPV_OC_OpIMul, SPV_OC_OpFMul, SPV_OC_OpUDiv, + SPV_OC_OpSDiv, SPV_OC_OpFDiv, SPV_OC_OpUMod, SPV_OC_OpSRem, SPV_OC_OpSMod, + SPV_OC_OpFRem, SPV_OC_OpFMod, SPV_OC_OpMatrixTimesScalar, + SPV_OC_OpLogicalEqual, SPV_OC_OpLogicalNotEqual, SPV_OC_OpLogicalOr, + SPV_OC_OpLogicalAnd, SPV_OC_OpLogicalNot, SPV_OC_OpSelect, SPV_OC_OpIEqual, + SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, + SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, + SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, + SPV_OC_OpFOrdEqual, SPV_OC_OpFUnordEqual, SPV_OC_OpFOrdNotEqual, + SPV_OC_OpFUnordNotEqual, SPV_OC_OpFOrdLessThan, SPV_OC_OpFUnordLessThan, + SPV_OC_OpFOrdGreaterThan, SPV_OC_OpFUnordGreaterThan, SPV_OC_OpFOrdLessThanEqual, SPV_OC_OpFUnordLessThanEqual, SPV_OC_OpFOrdGreaterThanEqual, SPV_OC_OpFUnordGreaterThanEqual, SPV_OC_OpShiftRightLogical, SPV_OC_OpShiftRightArithmetic, diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVMatrixOps.td @@ -45,6 +45,13 @@ ``` }]; + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Matrix]> + ]; + let arguments = (ins SPV_AnyMatrix:$matrix, SPV_Float:$scalar @@ -72,4 +79,58 @@ // ----- +def SPV_TransposeOp : SPV_Op<"Transpose", []> { + let summary = "Transpose a matrix."; + + let description = [{ + Result Type must be an OpTypeMatrix. + + Matrix must be an object of type OpTypeMatrix. 
The number of columns and + the column size of Matrix must be the reverse of those in Result Type. + The types of the scalar components in Matrix and Result Type must be the + same. + + Matrix must be of type OpTypeMatrix. + + + + ``` + transpose-op ::= ssa-id `=` `spv.Transpose` ssa-use `:` matrix-type `->` + matrix-type + + ``` + + #### Example: + + ```mlir + %0 = spv.Transpose %matrix: !spv.matrix<2 x vector<3xf32>> -> + !spv.matrix<3 x vector<2xf32>> + + ``` + }]; + + let availability = [ + MinVersion, + MaxVersion, + Extension<[]>, + Capability<[SPV_C_Matrix]> + ]; + + let arguments = (ins + SPV_AnyMatrix:$matrix + ); + + let results = (outs + SPV_AnyMatrix:$result + ); + + let assemblyFormat = [{ + operands attr-dict `:` type($matrix) `->` type($result) + }]; + + let verifier = [{ return verifyTranspose(*this); }]; +} + +// ----- + #endif // SPIRV_MATRIX_OPS \ No newline at end of file diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -2815,6 +2815,36 @@ return success(); } +//===----------------------------------------------------------------------===// +// spv.Transpose +//===----------------------------------------------------------------------===// + +static LogicalResult verifyTranspose(spirv::TransposeOp op) { + auto inputMatrix = op.matrix().getType().cast<spirv::MatrixType>(); + auto resultMatrix = op.result().getType().cast<spirv::MatrixType>(); + + // Verify that the input and output matrices have correct shapes. 
+ if (auto inputMatrixColumns = + inputMatrix.getElementType().dyn_cast<VectorType>()) { + if (inputMatrixColumns.getNumElements() != resultMatrix.getNumElements()) + return op.emitError("input matrix rows count must be equal to " + "output matrix columns count"); + if (auto resultMatrixColumns = + resultMatrix.getElementType().dyn_cast<VectorType>()) { + if (resultMatrixColumns.getNumElements() != inputMatrix.getNumElements()) + return op.emitError("input matrix columns count must be equal " + "to output matrix rows count"); + + // Verify that the input and output matrices have the same component type + if (inputMatrixColumns.getElementType() != + resultMatrixColumns.getElementType()) + return op.emitError("input and output matrices must have the " + "same component type"); + } + } + return success(); +} + namespace mlir { namespace spirv { diff --git a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir --- a/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir +++ b/mlir/test/Dialect/SPIRV/Serialization/matrix.mlir @@ -22,6 +22,13 @@ spv.ReturnValue %result : !spv.matrix<3 x vector<3xf16>> } + + // CHECK-LABEL: @matrix_transpose_1 + spv.func @matrix_transpose_1(%arg0 : !spv.matrix<3 x vector<2xf32>>) -> !spv.matrix<2 x vector<3xf32>> "None" { + // CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>> + %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>> + spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>> + } } // ----- diff --git a/mlir/test/Dialect/SPIRV/matrix-ops.mlir b/mlir/test/Dialect/SPIRV/matrix-ops.mlir --- a/mlir/test/Dialect/SPIRV/matrix-ops.mlir +++ b/mlir/test/Dialect/SPIRV/matrix-ops.mlir @@ -2,11 +2,25 @@ spv.module Logical GLSL450 requires #spv.vce { // CHECK-LABEL: @matrix_times_scalar - spv.func @matrix_times_scalar_1(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" { + 
spv.func @matrix_times_scalar(%arg0 : !spv.matrix<3 x vector<3xf32>>, %arg1 : f32) -> !spv.matrix<3 x vector<3xf32>> "None" { // CHECK: {{%.*}} = spv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>> %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<3 x vector<3xf32>> spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>> } + + // CHECK-LABEL: @matrix_transpose_1 + spv.func @matrix_transpose_1(%arg0 : !spv.matrix<3 x vector<2xf32>>) -> !spv.matrix<2 x vector<3xf32>> "None" { + // CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>> + %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<2xf32>> -> !spv.matrix<2 x vector<3xf32>> + spv.ReturnValue %result : !spv.matrix<2 x vector<3xf32>> + } + + // CHECK-LABEL: @matrix_transpose_2 + spv.func @matrix_transpose_2(%arg0 : !spv.matrix<3 x vector<3xf32>>) -> !spv.matrix<3 x vector<3xf32>> "None" { + // CHECK: {{%.*}} = spv.Transpose {{%.*}} : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>> + %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<3xf32>> -> !spv.matrix<3 x vector<3xf32>> + spv.ReturnValue %result : !spv.matrix<3 x vector<3xf32>> + } } // ----- @@ -37,5 +51,26 @@ %result = spv.MatrixTimesScalar %arg0, %arg1 : !spv.matrix<3 x vector<3xf32>>, f32 -> !spv.matrix<4 x vector<3xf32>> } +// ----- + +func @transpose_op_shape_mismatch_1(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () { + // expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}} + %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<3 x vector<3xf32>> + spv.Return +} + +// ----- +func @transpose_op_shape_mismatch_2(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () { + // expected-error @+1 {{input matrix columns count must be equal to output matrix rows count}} + %result = spv.Transpose %arg0 : !spv.matrix<3 x 
vector<4xf32>> -> !spv.matrix<4 x vector<2xf32>> + spv.Return +} +// ----- + +func @transpose_op_type_mismatch(%arg0 : !spv.matrix<3 x vector<4xf32>>) -> () { + // expected-error @+1 {{input and output matrices must have the same component type}} + %result = spv.Transpose %arg0 : !spv.matrix<3 x vector<4xf32>> -> !spv.matrix<4 x vector<3xf16>> + spv.Return +}