diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -730,12 +730,34 @@
}
};
+/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
+///
+/// Matches a tensor.extract whose source tensor is produced directly by a
+/// linalg.fill and replaces the extract with the scalar operand of the fill.
+/// The extract indices are not inspected.
+struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
+public:
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  /// Returns success() after replacing `extractOp` with the fill scalar, or
+  /// failure() (leaving the IR untouched) when the pattern does not apply.
+  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // See if tensor input of tensor.extract op is the result of a linalg.fill
+    // op.
+    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return failure();
+
+    // Get scalar input operand of linalg.fill op.
+    Value extractedScalar = fillOp.getInputs()[0];
+
+    // linalg.fill may cast its scalar operand to the element type of the
+    // output. Forwarding the scalar is only type-correct when no such cast is
+    // involved, so bail out if the types differ.
+    if (extractedScalar.getType() != extractOp.getResult().getType())
+      return failure();
+
+    // Replace tensor.extract op with scalar value used to fill the tensor.
+    rewriter.replaceOp(extractOp, extractedScalar);
+    return success();
+  }
+};
+
} // namespace
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
+  // FoldFillWithTensorExtract is rooted on tensor.extract, but it is
+  // registered here together with the other linalg.fill folding patterns.
   results
-      .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+      .add<FoldFillWithTensorExtract,
+           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
            FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
            FoldInsertPadIntoFill>(context);
 }
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -335,6 +335,22 @@
return %1 : tensor
}
+// -----
+// An element extracted from a linalg.fill result should canonicalize to the
+// scalar that was used to fill the tensor.
+// CHECK: func @fold_fill_extract
+// CHECK-SAME: %[[ARG0:.+]]: i1
+func.func @fold_fill_extract(%arg0 : i1) -> i1 {
+  %idx = arith.constant 0 : index
+  %dyn_size = arith.constant 1 : index
+  %init = tensor.empty(%dyn_size) : tensor<1x2x3x?xi1>
+  %fill = linalg.fill ins(%arg0 : i1) outs(%init : tensor<1x2x3x?xi1>) -> tensor<1x2x3x?xi1>
+  %elem = tensor.extract %fill[%idx, %idx, %idx, %idx] : tensor<1x2x3x?xi1>
+
+  // CHECK: return %[[ARG0]]
+  return %elem : i1
+}
+
// -----
// CHECK: func @fold_self_copy