diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml @@ -949,19 +949,19 @@ indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, - s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d5)> + s9, s10, s11, s12] -> (d0, d1 * s9 + d5 * s11, d2 * s10 + d6 * s12, d3)> - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, - s9, s10, s11, s12] -> (d3, d4, d5, d6)> + s9, s10, s11, s12] -> (d5, d6, d3, d4)> - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, - s9, s10, s11, s12] -> (d0, d1, d2, d5, d6)> + s9, s10, s11, s12] -> (d0, d1, d2, d3, d4)> iterator_types: - parallel - parallel - parallel - - reduction - - reduction - parallel - parallel + - reduction + - reduction assignments: - !ScalarAssign arg: O @@ -1039,23 +1039,23 @@ indexing_maps: !LinalgIndexingMapsConfig static_indexing_maps: - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, - s9, s10, s11, s12] -> (d0, d1 * s9 + d3 * s11, d2 * s10 + d4 * s12, d5)> + s9, s10, s11, s12] -> (d0, d1 * s9 + d5 * s11, d2 * s10 + d6 * s12, d3)> - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, - s9, s10, s11, s12] -> (d3, d4, d5, d6)> + s9, s10, s11, s12] -> (d5, d6, d3, d4)> - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12] -> ()> - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11, s12] -> ()> - affine_map<(d0, d1, d2, d3, d4, d5, d6)[s0, s1, s2, s3, s4, s5, s6, s7, s8, - s9, s10, s11, s12] -> (d0, d1, d2, d5, d6)> + s9, s10, s11, s12] -> (d0, d1, d2, d3, d4)> iterator_types: - parallel - parallel - parallel - - reduction - - reduction - parallel - parallel + - reduction + - reduction assignments: - !ScalarAssign arg: O diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -209,7 +209,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.ic, D.cm) + domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic, D.cm] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm]) @@ -228,7 +228,7 @@ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.ic, D.cm) + domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic, D.cm] += ( (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - cast(U, IZp)) *