Index: mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -635,6 +635,33 @@
   }
 };
 
+/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
+struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
+  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    // See if the tensor operand of tensor.extract is the result of a
+    // linalg.fill op.
+    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
+    if (!fillOp)
+      return failure();
+
+    // Get the scalar input operand of linalg.fill.
+    Value extractedScalar = fillOp.getInputs()[0];
+
+    // Only fold when the fill value already has the extracted element type;
+    // otherwise substituting it directly would produce ill-typed IR.
+    if (extractedScalar.getType() != extractOp.getType())
+      return failure();
+
+    // Every element of the filled tensor is the fill value, so the extract
+    // is replaced by that existing value; no new op needs to be created.
+    rewriter.replaceOp(extractOp, extractedScalar);
+    return success();
+  }
+};
+
 /// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
 /// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
 /// filling value are the same.
@@ -735,7 +762,8 @@
 void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
   results
-      .add<FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
+      .add<FoldFillWithTensorExtract,
+           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
            FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
            FoldInsertPadIntoFill>(context);
 }