diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -650,6 +650,29 @@
   }
 };
 
+/// Swaps tensor.extract_slice(linalg.fill(%cst, %init)) into linalg.fill(%cst,
+/// tensor.extract_slice(%init)) when the linalg.fill op has no other users.
+/// This helps to reduce the fill footprint.
+struct SwapExtractSliceOfFill final
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    auto fillOp = extractOp.getSource().getDefiningOp<FillOp>();
+    if (!fillOp || !fillOp->hasOneUse())
+      return failure();
+
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        extractOp.getLoc(), extractOp.getType(), fillOp.getOutputs()[0],
+        extractOp.getMixedOffsets(), extractOp.getMixedSizes(),
+        extractOp.getMixedStrides());
+    rewriter.replaceOpWithNewOp<FillOp>(extractOp, fillOp.getInputs(),
+                                        ValueRange{newExtractOp.getResult()});
+    return success();
+  }
+};
+
 } // namespace
 
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -657,7 +680,7 @@
   results
       .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
            FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
-           FoldInsertPadIntoFill>(context);
+           FoldInsertPadIntoFill, SwapExtractSliceOfFill>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -927,3 +927,32 @@
 // CHECK-SAME: outs(%[[INIT2]], %[[INIT1]] :
 // CHECK: %[[RETURN_CAST:.+]] = tensor.cast %[[GENERIC]]#0 : tensor<3x2x4xf32> to tensor<?x?x?xf32>
 // CHECK: return %[[RETURN_CAST]], %[[GENERIC]]#1
+
+// -----
+
+// CHECK-LABEL: func.func @swap_fill_extract_slice
+// CHECK-SAME: (%[[INIT:.+]]: tensor<?x?x?xf32>, %[[OFFSET0:.+]]: index, %[[SIZE1:.+]]: index)
+// CHECK: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EXT:.+]] = tensor.extract_slice %[[INIT]][%[[OFFSET0]], 8, 4] [1, %[[SIZE1]], 6] [1, 3, 1]
+// CHECK: %[[FILL:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[EXT]] : tensor<?x6xf32>) -> tensor<?x6xf32>
+// CHECK: return %[[FILL]]
+func.func @swap_fill_extract_slice(%init : tensor<?x?x?xf32>, %offset0: index, %size1: index) -> tensor<?x6xf32> {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%f0 : f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %1 = tensor.extract_slice %0[%offset0, 8, 4] [1, %size1, 6] [1, 3, 1]
+    : tensor<?x?x?xf32> to tensor<?x6xf32>
+  return %1: tensor<?x6xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @dont_swap_fill_extract_slice_multi_user
+// CHECK: linalg.fill
+// CHECK: tensor.extract_slice
+func.func @dont_swap_fill_extract_slice_multi_user(%init : tensor<?x?x?xf32>, %offset0: index, %size1: index) -> (tensor<?x?x?xf32>, tensor<2x?x6xf32>) {
+  %f0 = arith.constant 0.000000e+00 : f32
+  %0 = linalg.fill ins(%f0 : f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  %1 = tensor.extract_slice %0[%offset0, 8, 4] [2, %size1, 6] [1, 3, 1]
+    : tensor<?x?x?xf32> to tensor<2x?x6xf32>
+  return %0, %1: tensor<?x?x?xf32>, tensor<2x?x6xf32>
+}
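
Note (not part of the patch): a minimal sketch of how the new pattern gets exercised programmatically, assuming the standard greedy rewrite driver; the helper name canonicalizeFills and the surrounding setup are illustrative only. In-tree, the same patterns run whenever the -canonicalize pass executes, which is what the canonicalize.mlir test above relies on.

// Illustrative sketch only -- not part of this patch. Collects FillOp's
// canonicalization patterns (now including SwapExtractSliceOfFill) and runs
// them with the greedy rewrite driver over a module.
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult canonicalizeFills(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  RewritePatternSet patterns(ctx);
  // Registers all FillOp canonicalizers, including the new swap pattern.
  linalg::FillOp::getCanonicalizationPatterns(patterns, ctx);
  return applyPatternsAndFoldGreedily(module, std::move(patterns));
}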