diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -187,40 +187,15 @@
       return failure();
     SmallVector<int64_t> dims = genericOp.getStaticShape();
 
-    // Find all the reduction iterators. Those need some special consideration
-    // (see below).
-    auto getLoopDimsOfType =
-        [&](StringRef iteratorTypeName) -> SmallVector<unsigned, 4> {
-      SmallVector<AffineExpr> dimExprs;
-      getDimsOfType(genericOp, iteratorTypeName, dimExprs);
-      return llvm::to_vector<4>(llvm::map_range(dimExprs, [](AffineExpr expr) {
-        return expr.cast<AffineDimExpr>().getPosition();
-      }));
-    };
-    auto reductionDims = getLoopDimsOfType(getReductionIteratorTypeName());
-
+    // Drop every loop whose extent, as read from the static operand shapes, is
+    // 1 -- parallel and reduction loops alike. A one-trip reduction loop only
+    // combines a single element with the initial value held by the output.
     DenseSet<unsigned> unitDims;
-    SmallVector<unsigned, 4> unitDimsReductionLoops;
     ArrayAttr iteratorTypes = genericOp.iterator_types();
     for (auto expr : enumerate(invertedMap.getResults())) {
       if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
-        if (dims[dimExpr.getPosition()] == 1) {
-          if (isParallelIterator(iteratorTypes[expr.index()]))
-            unitDims.insert(expr.index());
-          else if (isReductionIterator(iteratorTypes[expr.index()]))
-            unitDimsReductionLoops.push_back(expr.index());
-        }
-    }
-
-    // Reduction loops can be dropped if there is at least one other reduction
-    // loop that is not dropped. This accounts for the initial value read in the
-    // reduction loop.
-    if (!unitDimsReductionLoops.empty() && reductionDims.size() > 1) {
-      if (unitDimsReductionLoops.size() == reductionDims.size())
-        unitDims.insert(reductionDims.begin(), std::prev(reductionDims.end()));
-      else
-        unitDims.insert(unitDimsReductionLoops.begin(),
-                        unitDimsReductionLoops.end());
+        if (dims[dimExpr.getPosition()] == 1)
+          unitDims.insert(expr.index());
     }
 
     if (unitDims.empty())
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -361,7 +361,7 @@
 
 // -----
 
-func @unit_dim_for_reduction_keep_one(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
+func @unit_dim_for_both_reduction(%arg0: tensor<1x?x1x1xf32>) -> tensor<1x1xf32> {
   %cst = constant 1.000000e+00 : f32
   %c3 = constant 3 : index
   %1 = linalg.init_tensor [1, 1] : tensor<1x1xf32>
@@ -378,17 +378,16 @@
   } -> tensor<1x1xf32>
   return %3 : tensor<1x1xf32>
 }
-// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1) -> (d0)>
-// CHECK: func @unit_dim_for_reduction_keep_one
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @unit_dim_for_both_reduction
 // CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x1xf32>
-// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
+// CHECK-DAG: %[[RESHAPE:.+]] = linalg.tensor_collapse_shape %[[ARG0]] {{\[}}[0, 1, 2, 3]]
 // CHECK: %[[INIT:.+]] = linalg.init_tensor [1] : tensor<1xf32>
 // CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[INIT]])
 // CHECK: %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]]]
-// CHECK-SAME: iterator_types = ["parallel", "reduction"]
-// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x1xf32>)
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?xf32>)
 // CHECK-SAME: outs(%[[FILL]] : tensor<1xf32>)
 // CHECK: %[[RESULT_RESHAPE:.+]] = linalg.tensor_expand_shape %[[RESULT]] {{\[}}[0, 1]]
 // CHECK: return %[[RESULT_RESHAPE]]