diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -367,46 +367,23 @@
 }
 
 namespace {
-/// Lower tensor.from_elements to a sequence of chained tensor.insert.
-struct FromElementsOpConverter : public OpRewritePattern<FromElementsOp> {
-  using OpRewritePattern<FromElementsOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(FromElementsOp fromElementsOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(
-            linalg::rewriteInDestinationPassingStyle(rewriter, fromElementsOp)))
-      return failure();
-    return success();
-  }
-};
 
-/// Lower tensor.generate to linalg.generic.
-struct GenerateOpConverter : public OpRewritePattern<GenerateOp> {
-  using OpRewritePattern<GenerateOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenerateOp generateOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, generateOp)))
-      return failure();
-    return success();
-  }
-};
+/// Rewrite `op` in destination-passing style by forwarding to
+/// linalg::rewriteInDestinationPassingStyle. Registered below for:
+///   * tensor.from_elements -> a sequence of chained tensor.insert,
+///   * tensor.generate -> linalg.generic,
+///   * tensor.pad -> linalg.generic + tensor.insert_slice.
+template <typename OpTy>
+LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
+                                                 PatternRewriter &rewriter) {
+  return linalg::rewriteInDestinationPassingStyle(rewriter, op);
+}
 
-/// Lower tensor.pad to linalg.generic + tensor.insert_slice.
-struct PadOpConverter : public OpRewritePattern<PadOp> {
-  using OpRewritePattern<PadOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(PadOp padOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(linalg::rewriteInDestinationPassingStyle(rewriter, padOp)))
-      return failure();
-    return success();
-  }
-};
 } // namespace
 
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<FromElementsOpConverter, GenerateOpConverter, PadOpConverter>(
-      patterns.getContext());
+  patterns.add(rewriteOpInDestinationPassingStyle<FromElementsOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<GenerateOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<PadOp>);
 }