diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -181,7 +181,7 @@
     name: rhs
     usage: InputOperand
     type_var: RhsType
-    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s3, s5)>
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s4, s1, s5, s3)>
   - !LinalgOperandDefConfig
     name: accum
     usage: OutputOperand
@@ -189,19 +189,19 @@
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5] -> (s0, s4, s2, s5)>
   indexing_maps: !LinalgIndexingMapsConfig
     static_indexing_maps:
-    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d4, d2,
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d2, d3,
       d5)>
-    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d1, d4, d5,
-      d3)>
-    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d1, d2,
-      d3)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d1, d2, d4,
+      d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5] -> (d0, d1, d3,
+      d4)>
   iterator_types:
   - parallel
   - parallel
+  - reduction
   - parallel
   - parallel
   - reduction
-  - reduction
   assignments:
   - !ScalarAssign
     arg: accum
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -39,7 +39,7 @@
 
 @linalg_structured_op
 def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0),
-          rhs=TensorDef(TV.RhsType, S.N, S.K, S.K0, S.N0),
+          rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0),
           accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0,
                           output=True)):
   """Performs a matrix-matrix-transpose multiplication of two 4D inputs.
@@ -52,9 +52,9 @@
     '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads
     as: MxK tiles, each of shape M0xK0.
   """
-  domain(D.m, D.n, D.m0, D.n0, D.k, D.k0)
+  domain(D.m, D.n, D.k, D.m0, D.n0, D.k0)
   implements(ContractionOpInterface)
-  accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.k0, D.n0])
+  accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0])
 
 @linalg_structured_op
 def batch_matmul(
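
For reference, a minimal NumPy sketch of the contraction the corrected definition expresses: accum[m, n, m0, n0] += lhs[m, k, m0, k0] * rhs[n, k, n0, k0], with the RHS inner tile laid out as N0xK0 rather than K0xN0. This is illustrative only; the concrete tile sizes below are made-up example values, not taken from this patch.

# Illustrative NumPy model of the corrected mmt4d contraction.
# M, N, K, M0, N0, K0 are arbitrary example sizes (hypothetical).
import numpy as np

M, N, K, M0, N0, K0 = 2, 4, 3, 8, 4, 1

lhs = np.random.rand(M, K, M0, K0).astype(np.float32)   # LHS: MxK tiles, each M0xK0
rhs = np.random.rand(N, K, N0, K0).astype(np.float32)   # RHS: NxK tiles, each N0xK0 (fixed layout)
accum = np.zeros((M, N, M0, N0), dtype=np.float32)      # accum: MxN tiles, each M0xN0

# accum[m, n, m0, n0] += sum over k, k0 of lhs[m, k, m0, k0] * rhs[n, k, n0, k0]
accum += np.einsum("mkab,nkcb->mnac", lhs, rhs)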