diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -730,12 +730,34 @@ } }; +/// Fold tensor.extract(linalg.fill(%value)) into %value: the extract is replaced by the scalar operand that was used to fill the tensor. +struct FoldFillWithTensorExtract : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp extractOp, + PatternRewriter &rewriter) const override { + // See if tensor input of tensor.extract op is the result of a linalg.fill op. + auto fillOp = extractOp.getTensor().getDefiningOp(); + if (!fillOp) + return failure(); + + // Get scalar input operand of linalg.fill op. NOTE(review): this value replaces the extract result directly, which assumes the fill operand type equals the tensor element type; if linalg.fill can implicitly cast the fill value (e.g. f32 into an f16 tensor), consider bailing out when extractedScalar.getType() != extractOp.getType() -- confirm. + Value extractedScalar = fillOp.getInputs()[0]; + + // Replace tensor.extract op with scalar value used to fill the tensor. + rewriter.replaceOp(extractOp, extractedScalar); + return success(); + } +}; + } // namespace void FillOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results - .add, + .add, FoldFillWithTensorReshape, FoldInsertPadIntoFill>(context); } diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -335,6 +335,22 @@ return %1 : tensor } +// ----- +// CHECK: func @fold_fill_extract +// CHECK-SAME: %[[ARG0:.+]]: i1 +func.func @fold_fill_extract(%arg0 : i1) -> i1 { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + + %empty_dynamic = tensor.empty(%c1) : tensor<1x2x3x?xi1> + %filled = linalg.fill ins(%arg0 : i1) outs(%empty_dynamic : tensor<1x2x3x?xi1>) -> tensor<1x2x3x?xi1> + + %extracted = tensor.extract %filled[%c0, %c0, %c0, %c0] : tensor<1x2x3x?xi1> + + // CHECK: return %[[ARG0]] + return %extracted : i1 +} + // ----- // CHECK: func @fold_self_copy