diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -463,6 +463,36 @@
   }
 };
 
+/// Swap extract_slice(fill) to fill(extract_slice), so that only the slice is
+/// filled.
+///
+/// Only swap the two ops if the extract_slice is the only user of the fill.
+struct SwapExtractSliceOfFill : OpRewritePattern<tensor::ExtractSliceOp> {
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp extractSliceOp,
+                                PatternRewriter &rewriter) const override {
+    auto oldFill = extractSliceOp.getSource().getDefiningOp<FillOp>();
+    // Only swap the ops if there is no other user of the fill.
+    if (!oldFill || !extractSliceOp.getSource().hasOneUse())
+      return failure();
+    // Extract from the old fill's output operand instead of its result.
+    rewriter.updateRootInPlace(extractSliceOp, [&]() {
+      extractSliceOp.getSourceMutable().assign(oldFill.output());
+    });
+    // Create a new fill and remove the old one.
+    rewriter.setInsertionPointAfter(extractSliceOp);
+    auto newFill =
+        rewriter.create<FillOp>(oldFill.getLoc(), ValueRange{oldFill.value()},
+                                ValueRange{extractSliceOp.getResult()});
+    rewriter.eraseOp(oldFill);
+    // Use the new fill instead of the extract_slice everywhere except in the
+    // new fill itself, which keeps consuming the extract_slice result.
+    rewriter.replaceAllUsesExcept(extractSliceOp.getResult(),
+                                  newFill.getResult(0), newFill);
+    return success();
+  }
+};
+
 /// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
 /// filling value are the same.
struct FoldFillWithPad final : public OpRewritePattern { @@ -607,7 +637,7 @@ results .add, FoldFillWithTensorReshape, - FoldInsertPadIntoFill>(context); + FoldInsertPadIntoFill, SwapExtractSliceOfFill>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -312,6 +312,20 @@ // ----- +// CHECK-LABEL: func @fold_fill_extract_slice( +// CHECK-SAME: %[[t:.*]]: tensor<1x1xf32> +func.func @fold_fill_extract_slice(%t: tensor<1x1xf32>) -> (tensor) { + %cst = arith.constant 0.000000e+00 : f32 + // CHECK: %[[e:.*]] = tensor.extract_slice %[[t]] + // CHECK: %[[f:.*]] = linalg.fill {{.*}} outs(%[[e]] : tensor) + %0 = linalg.fill ins(%cst : f32) outs(%t : tensor<1x1xf32>) -> tensor<1x1xf32> + %1 = tensor.extract_slice %0[0, 0] [1, 1] [1, 1] : tensor<1x1xf32> to tensor + // CHECK: return %[[f]] + return %1 : tensor +} + +// ----- + // CHECK: func @fold_fill_reshape_dynamic // CHECK-SAME: %[[ARG0:.+]]: tensor func.func @fold_fill_reshape_dynamic(%arg0 : tensor) -> tensor {