Index: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -635,6 +635,29 @@
}
};
+/// Fold tensor.extract(linalg.fill(<input>)) into <input>: every element of a
+/// filled tensor equals the fill value, so the read can be forwarded.
+struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
+public:
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // Only match when the source tensor is produced by linalg.fill.
+    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return failure();
+    // Every element of the filled tensor holds fill's scalar input (modulo a
+    // cast when types differ); only forward it when the types match exactly.
+    Value extractedScalar = fillOp.getInputs()[0];
+    if (extractedScalar.getType() != extractOp.getType())
+      return failure();
+    // Reuse the existing scalar; no new op is needed to produce it.
+    rewriter.replaceOp(extractOp, extractedScalar);
+    return success();
+  }
+};
+
/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
@@ -735,7 +758,8 @@
void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
-      .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+      .add<FoldFillWithTensorExtract, // extract(fill(v)) -> v
+           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
           FoldInsertPadIntoFill>(context);
}