diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -62,6 +62,79 @@
                 - !ScalarExpression
                   scalar_arg: B
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: mmt4d
+  cpp_class_name: Mmt4DOp
+  doc: |-
+    Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+
+    Differences from linalg.matmul:
+    * The right hand side is transposed, whence the 't' in 'mmt'.
+    * The input and output tensors have a 4D shape instead of a 2D shape. They
+      are interpreted as 2D matrices with one level of 2D tile subdivision,
+      whence the 2+2=4 dimensions. The inner tile dimensions are identified with
+      '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
+      as: MxK tiles, each of shape M0xK0.
+  implements:
+  - LinalgContractionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    usage: InputOperand
+    type_var: LhsType
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: rhs
+    usage: InputOperand
+    type_var: RhsType
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s5, s3)>
+  - !LinalgOperandDefConfig
+    name: accum
+    usage: OutputOperand
+    type_var: AccumType
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d4, d1,
+      d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d2, d4, d3,
+      d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d2, d1,
+      d3)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: accum
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: accum
+        - !ScalarExpression
+          scalar_apply:
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: AccumType
+                operands:
+                - !ScalarExpression
+                  scalar_arg: lhs
+            - !ScalarExpression
+              symbolic_cast:
+                type_var: AccumType
+                operands:
+                - !ScalarExpression
+                  scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: batch_matmul
   cpp_class_name: BatchMatmulOp
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -21,6 +21,26 @@
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
 
+@linalg_structured_op
+def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
+          rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
+          accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0,
+                          output=True)):
+  """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
+
+  Differences from linalg.matmul:
+  * The right hand side is transposed, whence the 't' in 'mmt'.
+  * The input and output tensors have a 4D shape instead of a 2D shape. They
+    are interpreted as 2D matrices with one level of 2D tile subdivision,
+    whence the 2+2=4 dimensions. The inner tile dimensions are identified with
+    '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
+    as: MxK tiles, each of shape M0xK0.
+  """
+  domain(D.m, D.m0, D.n, D.n0, D.k, D.k0)
+  implements(ContractionOpInterface)
+  accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
+
+
 @linalg_structured_op
 def batch_matmul(
     A=TensorDef(T1, Batch, S.M, S.K),
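Note (editorial, not part of the patch): the contraction defined above reduces over K and K0 and can be sanity-checked against plain NumPy by mirroring the OpDSL assignment accum[m, n, m0, n0] += lhs[m, k, m0, k0] * rhs[n, k, n0, k0]. The sketch below is illustrative only; the function name, tile sizes, and the einsum/untiling spelling are assumptions for the example, not anything this change defines.

import numpy as np

def mmt4d_reference(lhs, rhs, accum):
    """Reference semantics of the mmt4d contraction in the patch above.

    lhs:   (M, K, M0, K0)
    rhs:   (N, K, N0, K0)   # RHS is "transposed": K/K0 are its second/last dims.
    accum: (M, N, M0, N0)   # updated in place, matching output=True in the OpDSL.
    """
    # accum[m, n, m0, n0] += lhs[m, k, m0, k0] * rhs[n, k, n0, k0]
    accum += np.einsum('mkab,nkcb->mnac', lhs, rhs)
    return accum

# Example with arbitrary sizes (M=3, N=7, K=5) and tiles (M0=8, N0=4, K0=2).
lhs = np.random.rand(3, 5, 8, 2)
rhs = np.random.rand(7, 5, 4, 2)
accum = np.zeros((3, 7, 8, 4))
mmt4d_reference(lhs, rhs, accum)

# Cross-check: untile both operands into 2D matrices, do a plain matmul against
# the transposed RHS, then retile the result into the (M, N, M0, N0) layout.
lhs2d = lhs.transpose(0, 2, 1, 3).reshape(3 * 8, 5 * 2)   # (M*M0, K*K0)
rhs2d = rhs.transpose(0, 2, 1, 3).reshape(7 * 4, 5 * 2)   # (N*N0, K*K0)
expected = (lhs2d @ rhs2d.T).reshape(3, 8, 7, 4).transpose(0, 2, 1, 3)
assert np.allclose(accum, expected)

The untiled cross-check also illustrates why the op is named 'mmt': once the tiles are flattened away, it is an ordinary matmul against the transposed right hand side.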