Summary (free text ahead of the first "diff --git" header; patch tools skip it):

  DropUnitDims.cpp: matchAndRewrite no longer bails out when the op has init
  tensors. The `op.getNumInitTensors() != 0` condition is removed, so only the
  tensor-semantics check remains.

  drop-unit-extent-dims.mlir: adds @fold_unit_dim_for_init_tensor, a
  linalg.generic with one input (tensor<1x1000xf32>) and one `outs` operand
  (tensor<1xf32>). The CHECK lines expect the unit dimension to be folded:
  the input is reshaped to tensor<1000xf32>, the init operand to the rank-0
  tensor<f32>, the generic keeps a single "reduction" iterator, and the result
  is reshaped back to tensor<1xf32>.

  NOTE(review): this patch was recovered from a whitespace-collapsed copy.
  Line breaks and indentation are reconstructed to agree with the hunk line
  counts; confirm against the target tree (or apply with whitespace tolerance).

diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -294,8 +294,8 @@
   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
-    // TODO: support init_tensors and reductions.
-    if (!op.hasTensorSemantics() || op.getNumInitTensors() != 0)
+    // TODO: support reductions.
+    if (!op.hasTensorSemantics())
       return failure();
 
     MLIRContext *context = rewriter.getContext();
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -354,3 +354,37 @@
 // CHECK-LABEL: func @fold_unit_dim_tensor_reshape_op
 // CHECK: %[[RESULT:.+]] = linalg.generic
 // CHECK: return %[[RESULT]]
+
+// -----
+
+func @fold_unit_dim_for_init_tensor(%input: tensor<1x1000xf32>) -> tensor<1xf32> {
+  %cst = constant 0.0 : f32
+  %init = linalg.init_tensor [1] : tensor<1xf32>
+  %fill = linalg.fill(%init, %cst) : tensor<1xf32>, f32 -> tensor<1xf32>
+  %add = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+      ins(%input : tensor<1x1000xf32>)outs(%fill : tensor<1xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):
+      %1823 = addf %arg1, %arg2 : f32
+      linalg.yield %1823 : f32
+  } -> tensor<1xf32>
+  return %add : tensor<1xf32>
+}
+
+
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ()>
+
+// CHECK: func @fold_unit_dim_for_init_tensor
+
+// CHECK: %[[INPUT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [#[[MAP0]]] : tensor<1x1000xf32> into tensor<1000xf32>
+// CHECK: %[[INIT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [] : tensor<1xf32> into tensor<f32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction"]
+// CHECK-SAME: ins(%[[INPUT_RESHAPE]] : tensor<1000xf32>)
+// CHECK-SAME: outs(%[[INIT_RESHAPE]] : tensor<f32>)
+// CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_reshape %[[GENERIC]] [] : tensor<f32> into tensor<1xf32>
+// CHECK: return %[[GENERIC_RESHAPE]] : tensor<1xf32>