diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -908,11 +908,9 @@
     // the operands of the consumers that arent fused are the same.
     SmallVector<AffineMap> fusedIndexMaps = genericOp.getIndexingMaps();
 
-    // Accepted consumer maps are either identity or permutation.
-    auto invMap = inversePermutation(fusedIndexMaps[en.index()]);
-    // Compute the indexing map to use for the result of the producer.
-    AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
+    AffineMap modifiedMap =
+        linearizeCollapsedDims(fusedIndexMaps[en.index()], reshapeOp);
 
     // The modified map cannot have symbols.
     if (modifiedMap.getNumSymbols())
       return failure();
@@ -1159,11 +1157,9 @@
     // those for the operands of the producer.
     SmallVector<AffineMap> fusedIndexMaps = producer.getIndexingMaps();
 
-    auto invMap = inversePermutation(
-        producer.getTiedIndexingMap(producer.getOutputOperand(0)));
-    // Compute the indexing map to use for the operand of the producer.
-    AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
+    AffineMap modifiedMap = linearizeCollapsedDims(
+        producer.getTiedIndexingMap(producer.getOutputOperand(0)), reshapeOp);
 
     for (AffineExpr expr : modifiedMap.getResults()) {
       if (!expr.isPureAffine()) {
        return rewriter.notifyMatchFailure(
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion.mlir
@@ -1,5 +1,11 @@
 // RUN: mlir-opt -split-input-file -linalg-fold-reshape-ops-by-linearization %s | FileCheck %s
 
+// Note: These tests fuse the reshape ops by linearization. This can create
+// indexing maps which are hard to analyse later on. These patterns are useful
+// only if the folded dimensions in the reshape op are unit extent. Tests here
+// are more general for testing purposes, but use of these patterns for
+// non-unit dimensions should be deprecated.
+
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 func @generic_op_reshape_producer_fusion(%arg0 : tensor) ->
                                          tensor {
@@ -227,3 +233,55 @@
 // CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
 // CHECK-SAME:     ins(%[[ARG0]] : tensor<6x1xf32>)
 // CHECK-SAME:     outs(%[[T1]] : tensor<6xi32>)
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d0, d6, d3, d5, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
+func @permuted_dims_fusion_expand_shape(%arg0 : tensor<3x8x7x240xf32>) -> tensor<4x6x3x8x2x5x7xf32> {
+  %0 = tensor.expand_shape %arg0 [[0], [1, 2], [3], [4, 5, 6]]
+      : tensor<3x8x7x240xf32> into tensor<3x2x4x7x8x5x6xf32>
+  %1 = linalg.init_tensor [4, 6, 3, 8, 2, 5, 7] : tensor<4x6x3x8x2x5x7xf32>
+  %2 = linalg.generic {
+      indexing_maps = [#map0, #map1],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%0 : tensor<3x2x4x7x8x5x6xf32>) outs(%1 : tensor<4x6x3x8x2x5x7xf32>) {
+    ^bb0(%arg1 : f32, %arg2 : f32):
+      linalg.yield %arg1 : f32
+    } -> tensor<4x6x3x8x2x5x7xf32>
+  return %2 : tensor<4x6x3x8x2x5x7xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
+// CHECK: func @permuted_dims_fusion_expand_shape(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<3x8x7x240xf32>)
+// CHECK:   %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     ins(%[[ARG0]] :
+// CHECK:   return %[[RESULT]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d4, d0, d6, d3, d5, d1)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
+func @permuted_dims_fusion_collapse_shape(%arg0 : tensor<4x6x3x8x2x5x7xf32>) -> tensor<3x8x7x240xf32> {
+  %0 = linalg.init_tensor [3, 2, 4, 7, 8, 5, 6] : tensor<3x2x4x7x8x5x6xf32>
+  %1 = linalg.generic {
+      indexing_maps = [#map1, #map0],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<4x6x3x8x2x5x7xf32>) outs(%0 : tensor<3x2x4x7x8x5x6xf32>) {
+    ^bb0(%arg1 : f32, %arg2 : f32):
+      linalg.yield %arg1 : f32
+    } -> tensor<3x2x4x7x8x5x6xf32>
+  %2 = tensor.collapse_shape %1 [[0], [1, 2], [3], [4, 5, 6]]
+      : tensor<3x2x4x7x8x5x6xf32> into tensor<3x8x7x240xf32>
+  return %2 : tensor<3x8x7x240xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5, d6)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)>
+// CHECK: func @permuted_dims_fusion_collapse_shape(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x3x8x2x5x7xf32>)
+// CHECK:   %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME:     indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME:     ins(%[[ARG0]] :
+// CHECK:   return %[[RESULT]]
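
A worked derivation of the fused maps in the new tests (the stride arithmetic
below is spelled out purely for illustration; all numbers come from the tests
themselves): the reassociation [[0], [1, 2], [3], [4, 5, 6]] folds groups of
sizes (2, 4) and (8, 5, 6), and #map0 sends the loops (d0, ..., d6) to
(d2, d4, d0, d6, d3, d5, d1), so expanded dims 1 and 2 are indexed by d4 and
d0, and expanded dims 4, 5 and 6 by d3, d5 and d1. Row-major linearization
uses strides (4, 1) for the first group and (5 * 6, 6, 1) = (30, 6, 1) for the
second:

  group [1, 2]:     d4 * 4 + d0           = d0 + d4 * 4
  group [4, 5, 6]:  d3 * 30 + d5 * 6 + d1 = d1 + d3 * 30 + d5 * 6

which is exactly the linearized map the CHECK-DAG lines expect:

  affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2, d0 + d4 * 4, d6, d1 + d3 * 30 + d5 * 6)>

Because its results are sums of scaled dims rather than plain dims, this map
is no longer a projected permutation. That is what the note in the test file
warns about: only when the folded dimensions are unit extent do the multiplied
terms vanish and the map stay easy to analyse.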