diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -1658,12 +1658,33 @@
     return success();
   }
 };
+
+/// Fold linalg.fill -> linalg.tensor_reshape chain.
+///
+/// For such op chains, we can create new linalg.fill ops with the result
+/// type of the linalg.tensor_reshape op.
+struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
+  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
+                                PatternRewriter &rewriter) const override {
+    // Only fold when the reshape's source comes directly from a fill.
+    auto oldFill = reshapeOp.src().getDefiningOp<FillOp>();
+    if (!oldFill)
+      return failure();
+
+    // Create a new init_tensor with the reshaped result type, then fill it
+    // with the same value; the reshape becomes dead and is replaced.
+    auto newInit = rewriter.create<InitTensorOp>(
+        oldFill.getLoc(), reshapeOp.getResultType().getShape(),
+        reshapeOp.getResultType().getElementType());
+    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, newInit, oldFill.value());
+
+    return success();
+  }
+};
 } // namespace
 
 void TensorReshapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {
-  results.add<CollapseReshapeOps<TensorReshapeOp>, FoldReshapeWithConstant,
-              ReplaceDimOfReshapeOpResult>(context);
+  results.add<CollapseReshapeOps<TensorReshapeOp>, FoldFillWithTensorReshape,
+              FoldReshapeWithConstant, ReplaceDimOfReshapeOpResult>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -802,3 +802,19 @@
 //       CHECK: return
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_fill_reshape()
+func @fold_fill_reshape() -> tensor<6x4xf32> {
+  %zero = constant 0.0 : f32
+  // CHECK: %[[INIT:.+]] = linalg.init_tensor [6, 4] : tensor<6x4xf32>
+  %init = linalg.init_tensor [1, 2, 3, 4] : tensor<1x2x3x4xf32>
+  // CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<6x4xf32>, f32 -> tensor<6x4xf32>
+  %fill = linalg.fill(%init, %zero) : tensor<1x2x3x4xf32>, f32 -> tensor<1x2x3x4xf32>
+  %reshape = linalg.tensor_reshape %fill [
+      affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>,
+      affine_map<(d0, d1, d2, d3) -> (d3)>] : tensor<1x2x3x4xf32> into tensor<6x4xf32>
+  // CHECK: return %[[FILL]] : tensor<6x4xf32>
+  return %reshape : tensor<6x4xf32>
+}
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -379,12 +379,14 @@
 
 //      CHECK: func @fold_unit_dim_for_init_tensor
+
 //      CHECK:   %[[INPUT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [#[[MAP0]]] : tensor<1x1000xf32> into tensor<1000xf32>
-//      CHECK:   %[[INIT_RESHAPE:.+]] = linalg.tensor_reshape %{{.+}} [] : tensor<1xf32> into tensor<f32>
+//      CHECK:   %[[INIT:.+]] = linalg.init_tensor [] : tensor<f32>
+//      CHECK:   %[[FILL:.+]] = linalg.fill(%[[INIT]], %cst) : tensor<f32>, f32 -> tensor<f32>
 //      CHECK:   %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:     indexing_maps = [#[[MAP1]], #[[MAP2]]]
 // CHECK-SAME:     iterator_types = ["reduction"]
 // CHECK-SAME:     ins(%[[INPUT_RESHAPE]] : tensor<1000xf32>)
-// CHECK-SAME:     outs(%[[INIT_RESHAPE]] : tensor<f32>)
+// CHECK-SAME:     outs(%[[FILL]] : tensor<f32>)
 //      CHECK: %[[GENERIC_RESHAPE:.+]] = linalg.tensor_reshape %[[GENERIC]] [] : tensor<f32> into tensor<1xf32>
 //      CHECK:   return %[[GENERIC_RESHAPE:.+]] : tensor<1xf32>
 