diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -349,6 +349,23 @@
   return nullptr;
 }
 
+/// Check if the reshape operation is only an expansion into/collapsing of
+/// unit-dimensions.
+static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
+                                   ArrayRef<AffineMap> reassociation) {
+  for (auto &map : reassociation) {
+    unsigned numUnitDims = 0;
+    for (AffineExpr expr : map.getResults()) {
+      unsigned position = expr.cast<AffineDimExpr>().getPosition();
+      if (expandedShape[position] == 1)
+        numUnitDims++;
+    }
+    if (numUnitDims != map.getNumResults() - 1)
+      return false;
+  }
+  return true;
+}
+
 /// Conditions for folding a generic/indexed-generic operation with a reshape op
 /// by expanding the iteration space dimensionality for tensor operations. These
 /// are preconditions assumed by `foldReshapeByDimExpansion` which implements
@@ -858,7 +875,9 @@
     //   - All constraints of fusing with reshape by expansion are met.
     if (reshapeOp.getSrcType().getRank() <
             reshapeOp.getResultType().getRank() ||
-        !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()))
+        !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
+        isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                               reshapeOp.getReassociationMaps()))
       continue;
 
     Optional<SmallVector<Value, 1>> replacementValues =
@@ -949,7 +968,10 @@
       return failure();
     LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
     if (!producer || producer.getNumOutputs() != 1 ||
-        !isFusableWithReshapeByDimExpansion(producer, producer.getNumInputs()))
+        !isFusableWithReshapeByDimExpansion(producer,
+                                            producer.getNumInputs()) ||
+        isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
+                               reshapeOp.getReassociationMaps()))
       return failure();
     Optional<SmallVector<Value, 1>> replacementValues =
         fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(),
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -188,42 +188,6 @@
 
 // -----
 
-func @scalar_reshape(
-  %arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>) -> tensor<1x10xf32>
-{
-  %0 = linalg.tensor_reshape %arg1 [] : tensor<1xf32> into tensor<f32>
-  %1 = linalg.init_tensor [10] : tensor<10xf32>
-  %2 = linalg.generic
-    {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
-     iterator_types = ["parallel"]}
-    ins(%0 : tensor<f32>)
-    outs(%1 : tensor<10xf32>) {
-  ^bb0(%arg2: f32, %s: f32):  // no predecessors
-    linalg.yield %arg2 : f32
-  } -> tensor<10xf32>
-  %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>]
-    : tensor<10xf32> into tensor<1x10xf32>
-  return %3 : tensor<1x10xf32>
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> ()>
-// CHECK: func @scalar_reshape
-// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10xf32>
-// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1xf32>
-// CHECK:   %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] []
-// CHECK-SAME:     tensor<1xf32> into tensor<f32>
-// CHECK:   %[[T1:.+]] = linalg.init_tensor [10]
-// CHECK:   %[[T2:.+]] = linalg.tensor_reshape %[[T1]] [#[[MAP0]]]
-// CHECK:   %[[T3:.+]] = linalg.generic
-// CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP0]]]
-// CHECK-SAME:     iterator_types = ["parallel", "parallel"]
-// CHECK-SAME:     ins(%[[T0]] : tensor<f32>)
-// CHECK-SAME:     outs(%[[T2]] : tensor<1x10xf32>)
-// CHECK:   return %[[T3]] : tensor<1x10xf32>
-
-// -----
-
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
 func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor,
@@ -336,7 +300,7 @@
       %5 = addi %3, %4 : i32
       %6 = index_cast %arg2 : index to i32
       %7 = addi %5, %6 : i32
-      linalg.yield %7 : i32 
+      linalg.yield %7 : i32
   } -> tensor<6x4x210xi32>
   %d = linalg.tensor_reshape %c
      [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
@@ -493,3 +457,77 @@
 // CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor, tensor)
 // CHECK-SAME:     outs(%[[T2]] : tensor)
 // CHECK:   return %[[T3]] : tensor
+
+// -----
+
+func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1) -> (d0, d1)>] : tensor<1x5xf32> into tensor<5xf32>
+  %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
+  %2 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%0 : tensor<5xf32>) outs(%1 : tensor<5x5xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<5x5xf32>
+  return %2 : tensor<5x5xf32>
+}
+// CHECK: func @unit_dim_reshape_expansion
+// CHECK-DAG:   linalg.tensor_reshape
+// CHECK-DAG:   linalg.init_tensor
+// CHECK:   linalg.generic
+
+// -----
+
+func @unit_dim_reshape_collapse(%arg0 : tensor<5xf32>) -> tensor<5x1x5xf32> {
+  %0 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<5x5xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+    linalg.yield %arg2 : f32
+  } -> tensor<5x5xf32>
+  %2 = linalg.tensor_reshape %1
+    [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+    : tensor<5x5xf32> into tensor<5x1x5xf32>
+  return %2 : tensor<5x1x5xf32>
+}
+// CHECK: func @unit_dim_reshape_collapse
+// CHECK:   linalg.init_tensor
+// CHECK:   linalg.generic
+// CHECK:   linalg.tensor_reshape
+
+// -----
+
+func @unit_dim_reshape_expansion_full
+  (%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor<?x2x4xf32>)
+  -> tensor<?x2x4xf32> {
+  %c1 = constant 1 : index
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>]
+    : tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>
+  %1 = dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32>
+  %2 = linalg.init_tensor [%1, 2, 4] : tensor<?x2x4xf32>
+  %3 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%0, %arg1 : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
+    outs(%2 : tensor<?x2x4xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
+    %4 = mulf %arg2, %arg3 : f32
+    linalg.yield %4 : f32
+  } -> tensor<?x2x4xf32>
+  return %3 : tensor<?x2x4xf32>
+}
+// CHECK: func @unit_dim_reshape_expansion_full
+// CHECK-DAG:   linalg.tensor_reshape
+// CHECK-DAG:   linalg.init_tensor
+// CHECK:   linalg.generic
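
For reference, the new predicate can be exercised in isolation. The following is a minimal standalone sketch of the same logic, not the patch itself: it uses plain STL containers in place of MLIR's ArrayRef<int64_t>/ArrayRef<AffineMap>, and models each reassociation group as the list of expanded-dimension indices it covers (an illustrative representation; the real code walks the AffineMap results).

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of isUnitDimExpansionOnly: a reshape is "unit-dim only" when every
// reassociation group contains exactly one non-unit dimension, i.e. the
// reshape merely inserts or drops size-1 dimensions.
static bool isUnitDimExpansionOnlySketch(
    const std::vector<int64_t> &expandedShape,
    const std::vector<std::vector<unsigned>> &reassociation) {
  for (const auto &group : reassociation) {
    size_t numUnitDims = 0;
    for (unsigned position : group)
      if (expandedShape[position] == 1)
        ++numUnitDims;
    // All but one dimension in the group must be unit dimensions.
    if (numUnitDims != group.size() - 1)
      return false;
  }
  return true;
}

int main() {
  // tensor<1x5xf32> <-> tensor<5xf32> (the @unit_dim_reshape_expansion test):
  // the single group {0, 1} over shape [1, 5] has one non-unit dim, so the
  // reshape is unit-dim only and fusion is now skipped.
  assert(isUnitDimExpansionOnlySketch({1, 5}, {{0, 1}}));
  // An expansion that splits real dimensions, e.g. shape [3, 2, 4] with
  // groups {0, 1} and {2}, is not unit-dim only and still fuses.
  assert(!isUnitDimExpansionOnlySketch({3, 2, 4}, {{0, 1}, {2}}));
  return 0;
}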