diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -326,7 +326,7 @@
   if ((asProducer && returnType.getRank() < operandType.getRank()) ||
       (!asProducer && operandType.getRank() < returnType.getRank()))
     return false;
-  return useIndexMap.isIdentity();
+  return useIndexMap.isPermutation();
 }
 
 /// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
@@ -381,10 +381,13 @@
         return attr.cast<AffineMapAttr>().getValue();
       }));
 
+  // Accepted consumer maps are either identity or permutation.
+  AffineMap invMap = inversePermutation(fusedIndexMaps[consumerIdx]);
+
   // Compute the indexing map to use for the operand of the producer.
-  AffineMap modifiedMap = linearizeCollapsedDims(
-      fusedIndexMaps[consumerIdx], producer.getResultType().getShape(),
-      producer.getReassociationMaps());
+  AffineMap modifiedMap =
+      linearizeCollapsedDims(invMap, producer.getResultType().getShape(),
+                             producer.getReassociationMaps());
   for (AffineExpr expr : modifiedMap.getResults()) {
     if (!expr.isPureAffine())
       return nullptr;
@@ -439,10 +442,14 @@
       producer.indexing_maps(), [](Attribute attr) -> AffineMap {
         return attr.cast<AffineMapAttr>().getValue();
       }));
+
+  // Accepted producer output maps are either identity or permutation.
+  AffineMap invMap = inversePermutation(producer.getOutputIndexingMap(0));
+
   // Compute the indexing map to use for the operand of the producer.
-  AffineMap modifiedMap = linearizeCollapsedDims(
-      producer.getOutputIndexingMap(0), consumer.getSrcType().getShape(),
-      consumer.getReassociationMaps());
+  AffineMap modifiedMap =
+      linearizeCollapsedDims(invMap, consumer.getSrcType().getShape(),
+                             consumer.getReassociationMaps());
   for (AffineExpr expr : modifiedMap.getResults()) {
     if (!expr.isPureAffine())
       return nullptr;
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -558,3 +558,100 @@
 // CHECK: linalg.indexed_generic
 // CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
 // CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1 + d2 * 7)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_021_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<3x7x5xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<3x7x5xf32>
+  return %1 : tensor<3x7x5xf32>
+}
+
+// CHECK-LABEL: func @generic_op_021_permutation_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d2, d0 * 7 + d1)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_120_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x7x3xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<5x7x3xf32>
+  return %1 : tensor<5x7x3xf32>
+}
+
+// CHECK-LABEL: func @generic_op_120_permutation_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+#map0 = affine_map<(d0, d1, d2) -> (d0)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+func @generic_op_102_permutation_reshape_producer_fusion(%arg0 : tensor<3x35xf32>) -> tensor<5x3x7xf32> {
+  %0 = linalg.tensor_reshape %arg0 [#map0, #map1] : tensor<3x35xf32> into tensor<3x5x7xf32>
+  %1 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<5x3x7xf32>
+  return %1 : tensor<5x3x7xf32>
+}
+
+// CHECK-LABEL: func @generic_op_102_permutation_reshape_producer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1, d2) -> (d1, d0 * 7 + d2)>
+
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d2)>
+func @generic_op_102_permutation_reshape_consumer_fusion(%arg0 : tensor<3x5x7xf32>) -> tensor<5x21xf32> {
+  %0 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<3x5x7xf32>) {
+    ^bb0(%arg2: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<5x3x7xf32>
+  %1 = linalg.tensor_reshape %0 [#map2, #map3] : tensor<5x3x7xf32> into tensor<5x21xf32>
+  return %1 : tensor<5x21xf32>
+}
+
+// CHECK-LABEL: func @generic_op_102_permutation_reshape_consumer_fusion
+// CHECK-NOT: linalg.tensor_reshape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-NOT: linalg.tensor_reshape