diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -3,7 +3,7 @@ name: matmul cpp_op_name: MatmulOp doc: |- - Performs a matrix multiplacation of two 2D inputs. + Performs a matrix multiplication of two 2D inputs. Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. @@ -305,4 +305,77 @@ operands: - !ScalarExpression scalar_arg: B +--- !LinalgOpConfig +metadata: !LinalgOpMetadata + name: mmt_4d_kernel + cpp_op_name: Mmt4DKernelOp + doc: |- + A lowering path for linalg.matmul towards efficient code generation on CPU. + + Differences from linalg.matmul: + * The right hand side is transposed, whence the 't' in 'mmt'. + * The input and output tensors have a 4D shape instead of a 2D shape. They + are interpreted as 2D matrices with one level of 2D tile subdivision, + whence the 2+2=4 dimensions. The inner tile dimensions are identified with + '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads + as: MxK tiles, each of shape M0xK0. + implements: + - LinalgContractionOpInterface +structured_op: !LinalgStructuredOpConfig + args: + - !<LinalgTensorDef> + name: lhs + usage: input + shape: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)> + element_type_var: LhsType + - !<LinalgTensorDef> + name: rhs + usage: input + shape: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s1, s4, s3, s5)> + element_type_var: RhsType + - !<LinalgTensorDef>
+ name: accum + usage: output + shape: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s1, s2, s3)> + element_type_var: AccumType + indexing_maps: !LinalgIndexingMapsConfig + static_indexing_maps: + - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d4, d2, + d5)> + - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d1, d4, d3, + d5)> + - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d1, d2, + d3)> + iterator_types: + - parallel + - parallel + - parallel + - parallel + - reduction + - reduction + assignments: + - !ScalarAssign + arg: accum + value: !ScalarExpression + scalar_apply: + fn_name: add + operands: + - !ScalarExpression + scalar_arg: accum + - !ScalarExpression + scalar_apply: + fn_name: mul + operands: + - !ScalarExpression + symbolic_cast: + type_var: AccumType + operands: + - !ScalarExpression + scalar_arg: lhs + - !ScalarExpression + symbolic_cast: + type_var: AccumType + operands: + - !ScalarExpression + scalar_arg: rhs diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -10,7 +10,7 @@ def matmul(A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): - """Performs a matrix multiplacation of two 2D inputs. + """Performs a matrix multiplication of two 2D inputs. Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
@@ -68,3 +68,22 @@ """ implements(ContractionOpInterface) C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) + + +@linalg_structured_op +def mmt_4d_kernel(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), + rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), + accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, + output=True)): + """A lowering path for linalg.matmul towards efficient code generation on CPU. + + Differences from linalg.matmul: + * The right hand side is transposed, whence the 't' in 'mmt'. + * The input and output tensors have a 4D shape instead of a 2D shape. They + are interpreted as 2D matrices with one level of 2D tile subdivision, + whence the 2+2=4 dimensions. The inner tile dimensions are identified with + '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads + as: MxK tiles, each of shape M0xK0. + """ + implements(ContractionOpInterface) + accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])