diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -169,9 +169,14 @@ SmallVector loops; SmallVector tensorResults; }; -FailureOr tileLinalgOp(OpBuilder &b, LinalgOp op, +FailureOr tileLinalgOp(RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options); +/// Peel the loops of a TiledLinalgOp. +void peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res, + ArrayRef peeledLoops, + LinalgTilingLoopType loopType); + /// Fuse a sequence of linalg operations (`ops`) using tile-and-fuse. This /// proceeds as follows: /// - Find outer parallel loops in these ops that can be fused. @@ -594,24 +599,35 @@ RewritePatternSet getLinalgTilingCanonicalizationPatterns(MLIRContext *ctx); void populateLinalgTilingCanonicalizationPatterns(RewritePatternSet &patterns); -/// Base pattern that applies the tiling transformation specified by `options`. -/// Abort and return failure in 2 cases: -/// 1. if the tiling specification is invalid and tiling fails to occur. -/// 2. if tiling occurs but `options.paddingValueComputationFunction` is set -/// and some operand shape cannot be bounded statically. -struct LinalgBaseTilingPattern : public RewritePattern { - // Entry point to match any LinalgOp OpInterface. - LinalgBaseTilingPattern( +/// +/// Linalg tiling pattern. +/// +/// Apply the `tiling` transformation as a pattern. +/// `filter` controls LinalgTransformMarker matching and update when specified. +/// See `tiling` for more details. +// TODO: TiledOpInterface +struct LinalgTilingPattern : public OpInterfaceRewritePattern { + /// Construct a generic pattern applied to all LinalgOp that verify `f`. + LinalgTilingPattern( MLIRContext *context, LinalgTilingOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1); - // Entry point to match a specific Linalg op. - LinalgBaseTilingPattern( + + /// Construct a pattern specifically applied to `opName`. + LinalgTilingPattern( StringRef opName, MLIRContext *context, LinalgTilingOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), + LinalgTransformationFilter f = LinalgTransformationFilter(), PatternBenefit benefit = 1); - LogicalResult matchAndRewriteBase(Operation *op, PatternRewriter &rewriter, - TiledLinalgOp &result) const; + + /// `matchAndRewrite` implementation that returns the significant transformed + /// pieces of IR. + FailureOr + returningMatchAndRewrite(LinalgOp op, PatternRewriter &rewriter) const; + + LogicalResult matchAndRewrite(LinalgOp op, + PatternRewriter &rewriter) const override { + return returningMatchAndRewrite(op, rewriter); + } private: /// LinalgTransformMarker handles special attribute manipulations. @@ -620,68 +636,6 @@ LinalgTilingOptions options; }; -template -struct LinalgTilingPattern : public LinalgBaseTilingPattern { - /// SFINAE: This constructor can only trigger for concrete ops that have a - /// static `getOperationName` method. - template - LinalgTilingPattern( - MLIRContext *context, LinalgTilingOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : LinalgBaseTilingPattern(ConcreateOpTy::getOperationName(), context, - options, filter, benefit) {} - - /// This constructor is available to anyone. - LinalgTilingPattern( - StringRef opName, MLIRContext *context, LinalgTilingOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : LinalgBaseTilingPattern(opName, context, options, filter, benefit) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - TiledLinalgOp tiledLinalgOp; - if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter, - tiledLinalgOp))) - return failure(); - if (tiledLinalgOp.tensorResults.empty()) - rewriter.eraseOp(op); - else - rewriter.replaceOp(op, tiledLinalgOp.tensorResults); - return success(); - } -}; - -struct LinalgGenericTilingPattern : public LinalgBaseTilingPattern { - /// Entry point to match any LinalgOp OpInterface. - /// MatchAnyOpTag-based constructor with a mandatory `filter`. - LinalgGenericTilingPattern( - MLIRContext *context, LinalgTransformationFilter filter, - LinalgTilingOptions options = LinalgTilingOptions(), - PatternBenefit benefit = 1) - : LinalgBaseTilingPattern(context, options, filter, benefit) {} - /// Entry point to match a specific Linalg op. - LinalgGenericTilingPattern( - StringRef opName, MLIRContext *context, LinalgTilingOptions options, - LinalgTransformationFilter filter = LinalgTransformationFilter(), - PatternBenefit benefit = 1) - : LinalgBaseTilingPattern(opName, context, options, filter, benefit) {} - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - TiledLinalgOp tiledLinalgOp; - if (failed(LinalgBaseTilingPattern::matchAndRewriteBase(op, rewriter, - tiledLinalgOp))) - return failure(); - if (tiledLinalgOp.tensorResults.empty()) - rewriter.eraseOp(op); - else - rewriter.replaceOp(op, tiledLinalgOp.tensorResults); - return success(); - } -}; - /// /// Linalg padding pattern. /// @@ -1395,6 +1349,32 @@ PatternRewriter &rewriter) const override; }; +//===----------------------------------------------------------------------===// +// Helper classes for type list expansion. +//===----------------------------------------------------------------------===// +template +class TilingPatterns; + +template <> +class TilingPatterns<> { +public: + static void insert(RewritePatternSet &patterns, + const LinalgTilingOptions &options, + const LinalgTransformationFilter &f) {} +}; + +template +class TilingPatterns { +public: + static void insert(RewritePatternSet &patterns, + const LinalgTilingOptions &options, + const LinalgTransformationFilter &f) { + patterns.add(OpTy::getOperationName(), + patterns.getContext(), options, f); + TilingPatterns::insert(patterns, options, f); + } +}; + } // namespace linalg } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -784,7 +784,9 @@ tileSizes[i] = zero; LinalgTilingOptions tileFusedLoopsOptions = options; tileFusedLoopsOptions.setTileSizes(tileSizes); - return tileLinalgOp(b, op, tileFusedLoopsOptions); + // TODO: Propagate RewriterBase everywhere. + IRRewriter rewriter(b); + return tileLinalgOp(rewriter, op, tileFusedLoopsOptions); } /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -283,10 +283,14 @@ tileInterchange.begin(), tileInterchange.end())) .setTileSizes(tileSizes) .setLoopType(LinalgTilingLoopType::Loops); - Optional tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions); + + // TODO: Propagate RewriterBase everywhere. + IRRewriter rewriter(b); + FailureOr tiledRootOp = + tileLinalgOp(rewriter, rootOp, tilingOptions); // Exit if tiling the root operation fails. - if (!tiledRootOp.hasValue()) + if (failed(tiledRootOp)) return failure(); // Replace all uses of the root operation if it has been tiled before. All diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp @@ -1,4 +1,4 @@ -//===- DynamicPass.cpp - Implementation of a dynamic configurable pass ----===// +//===- LinalgStrategyPasses.cpp - Implementation of Linalg passes ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -93,13 +93,14 @@ if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName) return; - RewritePatternSet tilingPattern(funcOp.getContext()); + MLIRContext *ctx = funcOp.getContext(); + RewritePatternSet tilingPattern(ctx); if (!anchorOpName.empty()) { - tilingPattern.add( - anchorOpName, funcOp.getContext(), options, filter); + tilingPattern.add(anchorOpName, ctx, options, + filter); } else { - tilingPattern.add(funcOp.getContext(), filter, - options); + tilingPattern.add( + linalg::GenericOp::getOperationName(), ctx, options, filter); } (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -51,7 +51,7 @@ // a map from loop indices of the LinalgOp to the corresponding non-empty range // indices of newly created loops. static std::tuple, LoopIndexToRangeIndexMap> -makeTiledLoopRanges(OpBuilder &b, Location loc, AffineMap map, +makeTiledLoopRanges(RewriterBase &b, Location loc, AffineMap map, ValueRange allShapeSizes, ValueRange allTileSizes) { assert(allTileSizes.size() == map.getNumResults()); // Apply `map` to get shape sizes in loop order. @@ -129,7 +129,7 @@ // TODO: Investigate whether mixing implicit and explicit indices // does not lead to losing information. static void -transformIndexOps(OpBuilder &b, LinalgOp op, SmallVectorImpl &ivs, +transformIndexOps(RewriterBase &b, LinalgOp op, SmallVectorImpl &ivs, const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) { SmallVector allIvs(op.getNumLoops(), nullptr); for (auto &en : enumerate(allIvs)) { @@ -144,7 +144,7 @@ // Insert a tile `source` into the destination tensor `dest`. The position at // which the tile is inserted (as well as size of tile) is taken from a given // ExtractSliceOp `sliceOp`. -static Value insertSliceIntoTensor(OpBuilder &b, Location loc, +static Value insertSliceIntoTensor(RewriterBase &b, Location loc, tensor::ExtractSliceOp sliceOp, Value source, Value dest) { return b.create( @@ -155,7 +155,7 @@ template static FailureOr -tileLinalgOpImpl(OpBuilder &b, LinalgOp op, ValueRange tileSizes, +tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ValueRange tileSizes, const LinalgTilingOptions &options) { auto nLoops = op.getNumLoops(); // Initial tile sizes may be too big, only take the first nLoops. @@ -216,7 +216,7 @@ LinalgOp res = op; SmallVector ivs, tensorResults; auto tiledLoopBodyBuilder = - [&](OpBuilder &b, Location loc, ValueRange localIvs, + [&](OpBuilder &builder, Location loc, ValueRange localIvs, ValueRange operandValuesToUse) -> scf::ValueVector { ivs.assign(localIvs.begin(), localIvs.end()); @@ -255,9 +255,12 @@ // TODO: use an interface/adaptor to avoid leaking position in // `tiledOperands`. Value outputTensor = tiledOperands[opOperand->getOperandNumber()]; + // TODO: Propagate RewriterBase everywhere. + IRRewriter rewriter(b); if (auto sliceOp = outputTensor.getDefiningOp()) { - tensorResults.push_back(insertSliceIntoTensor( - b, loc, sliceOp, res->getResult(resultIdx), sliceOp.source())); + tensorResults.push_back(insertSliceIntoTensor(rewriter, loc, sliceOp, + res->getResult(resultIdx), + sliceOp.source())); } else { tensorResults.push_back(res->getResult(resultIdx)); } @@ -299,7 +302,7 @@ template FailureOr static tileLinalgOpImpl( - OpBuilder &b, LinalgOp op, const LinalgTilingOptions &options) { + RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) { OpBuilder::InsertionGuard g(b); b.setInsertionPoint(op); @@ -321,7 +324,7 @@ } FailureOr -mlir::linalg::tileLinalgOp(OpBuilder &b, LinalgOp op, +mlir::linalg::tileLinalgOp(RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) { switch (options.loopType) { case LinalgTilingLoopType::Loops: @@ -338,7 +341,7 @@ /// Generate a loop nest around a given PadTensorOp (for tiling). `newPadOp` /// and `loopNest` are output parameters that return the new (tiled) PadTensorOp /// and the loop nest. -static LogicalResult tilePadTensorOp(OpBuilder &builder, PadTensorOp op, +static LogicalResult tilePadTensorOp(RewriterBase &builder, PadTensorOp op, PadTensorOp &newPadOp, LoopNest &loopNest, const LinalgTilingOptions &options) { Location loc = op.getLoc(); @@ -384,8 +387,10 @@ auto sliceOp = tiledOutput.getDefiningOp(); assert(sliceOp && "expected ExtractSliceOp"); // Insert the tile into the output tensor. + // TODO: Propagate RewriterBase everywhere. + IRRewriter rewriter(b); Value yieldValue = - insertSliceIntoTensor(b, loc, sliceOp, sliceOp, iterArgs[0]); + insertSliceIntoTensor(rewriter, loc, sliceOp, sliceOp, iterArgs[0]); return scf::ValueVector({yieldValue}); }); return success(); @@ -434,31 +439,6 @@ CanonicalizationPatternList::insert(patterns); } }; - -/// Helper classes for type list expansion. -template -class RewritePatternList; - -template <> -class RewritePatternList<> { -public: - static void insert(RewritePatternSet &patterns, - const LinalgTilingOptions &options) {} -}; - -template -class RewritePatternList { -public: - static void insert(RewritePatternSet &patterns, - const LinalgTilingOptions &options) { - auto *ctx = patterns.getContext(); - patterns.add>( - ctx, options, - LinalgTransformationFilter(ArrayRef{}, - StringAttr::get(ctx, "tiled"))); - RewritePatternList::insert(patterns, options); - } -}; } // namespace RewritePatternSet @@ -500,11 +480,14 @@ /// Populate the given list with patterns that apply Linalg tiling. static void insertTilingPatterns(RewritePatternSet &patterns, const LinalgTilingOptions &options) { - RewritePatternList{}, + StringAttr::get(ctx, "tiled")); + TilingPatterns::insert(patterns, options); - patterns.add(patterns.getContext(), options); + >::insert(patterns, options, f); + patterns.add(ctx, options); } static void applyExtractSliceOfPadTensorSwapPattern(FuncOp funcOp) { diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -1,4 +1,4 @@ -//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===// +//===- Transforms.cpp - Linalg transformations as patterns ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -284,19 +284,6 @@ return paddedSubviewResults; } -/// Linalg base tiling pattern. -mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( - StringRef opName, MLIRContext *context, LinalgTilingOptions options, - LinalgTransformationFilter filter, PatternBenefit benefit) - : RewritePattern(opName, benefit, context), filter(std::move(filter)), - options(std::move(options)) {} - -mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern( - MLIRContext *context, LinalgTilingOptions options, - LinalgTransformationFilter filter, PatternBenefit benefit) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context), - filter(std::move(filter)), options(std::move(options)) {} - /// Try to peel a loop `op` and return the new result. // TODO: Add support for scf.parallel and affine.for loops. static SmallVector peelLoop(RewriterBase &rewriter, Operation *op) { @@ -325,14 +312,15 @@ } /// Peel loops after tiling. -static void peelLoops(RewriterBase &rewriter, TiledLinalgOp &res, - const LinalgTilingOptions &options) { - for (int64_t loop : options.peeledLoops) { +void mlir::linalg::peelTiledLinalgOp(RewriterBase &rewriter, TiledLinalgOp &res, + ArrayRef peeledLoops, + LinalgTilingLoopType loopType) { + for (int64_t loop : peeledLoops) { assert(loop < static_cast(res.loops.size()) && "requested peeling of non-existing loop"); SmallVector loopResults; Operation *loopOp = res.loops[loop]; - if (options.loopType == LinalgTilingLoopType::TiledLoops) { + if (loopType == LinalgTilingLoopType::TiledLoops) { assert(llvm::all_of( res.loops, [&](Operation *op) { return op == res.loops.front(); }) && @@ -352,28 +340,6 @@ } } -LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase( - Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const { - LinalgOp linalgOp = dyn_cast(op); - if (!linalgOp) - return failure(); - if (failed(filter.checkAndNotify(rewriter, linalgOp))) - return failure(); - - Optional res = tileLinalgOp(rewriter, linalgOp, options); - - if (!res) - return failure(); - // Clear filter to stop recursive pattern application. - filter.replaceLinalgTransformationFilter(rewriter, res->op); - - // Peel loops. - peelLoops(rewriter, *res, options); - - result = *res; - return success(); -} - static ValueRange getTiledOpResult(TiledLinalgOp tiledOp) { if (tiledOp.loops.empty()) return tiledOp.op.getOperation()->getResults(); @@ -459,9 +425,9 @@ })) { LinalgTilingOptions unfusedTilingOptions = tilingOptions; unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes); - Optional unfusedTiledOp = + FailureOr unfusedTiledOp = tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions); - if (!unfusedTiledOp) + if (failed(unfusedTiledOp)) return failure(); rewriter.replaceOp(tiledAndFusedOps->op, getTiledOpResult(unfusedTiledOp.getValue())); @@ -485,6 +451,48 @@ return success(); } +/// Linalg tiling pattern. +mlir::linalg::LinalgTilingPattern::LinalgTilingPattern( + MLIRContext *context, LinalgTilingOptions options, + LinalgTransformationFilter f, PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + filter(std::move(f)), options(std::move(options)) {} + +mlir::linalg::LinalgTilingPattern::LinalgTilingPattern( + StringRef opName, MLIRContext *context, LinalgTilingOptions options, + LinalgTransformationFilter f, PatternBenefit benefit) + : OpInterfaceRewritePattern(context, benefit), + filter(std::move(f)), options(std::move(options)) { + this->filter.addFilter([opName](Operation *op) { + return success(op->getName().getStringRef() == opName); + }); +} + +FailureOr +mlir::linalg::LinalgTilingPattern::returningMatchAndRewrite( + LinalgOp op, PatternRewriter &rewriter) const { + if (failed(filter.checkAndNotify(rewriter, op))) + return failure(); + + FailureOr res = tileLinalgOp(rewriter, op, options); + if (failed(res)) + return failure(); + + // Clear filter to stop recursive pattern application. + // This must be done here to properly propagate to peeling branches. + filter.replaceLinalgTransformationFilter(rewriter, res->op); + + // Peel the loops of the TiledLinalgOp. + peelTiledLinalgOp(rewriter, *res, options.peeledLoops, options.loopType); + + if (res->tensorResults.empty()) + rewriter.eraseOp(op); + else + rewriter.replaceOp(op, res->tensorResults); + + return res; +} + /// Linalg padding pattern. mlir::linalg::LinalgPaddingPattern::LinalgPaddingPattern( MLIRContext *context, LinalgPaddingOptions options, diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1178,8 +1178,9 @@ constexpr static StringRef kTiledMarker = "TILED"; constexpr static StringRef kPromotedMarker = "PROMOTED"; - tilingPatterns.add>( - context, LinalgTilingOptions().setTileSizes(tileSizes), + tilingPatterns.add( + ConvOp::getOperationName(), context, + LinalgTilingOptions().setTileSizes(tileSizes), LinalgTransformationFilter(ArrayRef{}, StringAttr::get(kTiledMarker, context))); diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -138,32 +138,36 @@ //===--------------------------------------------------------------------===// // Linalg tiling patterns. //===--------------------------------------------------------------------===// - patterns.add>( - ctx, LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), + patterns.add( + MatmulOp::getOperationName(), ctx, + LinalgTilingOptions().setTileSizes({2000, 3000, 4000}), LinalgTransformationFilter(StringAttr::get(ctx, "MEM"), StringAttr::get(ctx, "L3"))); - patterns.add>( - ctx, LinalgTilingOptions().setTileSizes({200, 300, 400}), + patterns.add( + MatmulOp::getOperationName(), ctx, + LinalgTilingOptions().setTileSizes({200, 300, 400}), LinalgTransformationFilter(StringAttr::get(ctx, "L3"), StringAttr::get(ctx, "L2"))); - patterns.add>( - ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), + patterns.add( + MatmulOp::getOperationName(), ctx, + LinalgTilingOptions().setTileSizes({20, 30, 40}), LinalgTransformationFilter(StringAttr::get(ctx, "L2"), StringAttr::get(ctx, "L1"))); - patterns.add>( - ctx, LinalgTilingOptions().setTileSizes({2, 3, 4}), + patterns.add( + MatmulOp::getOperationName(), ctx, + LinalgTilingOptions().setTileSizes({2, 3, 4}), LinalgTransformationFilter(StringAttr::get(ctx, "L1"), StringAttr::get(ctx, "REG"))); - patterns.add>( - ctx, + patterns.add( + MatvecOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes({5, 6}).setLoopType( LinalgTilingLoopType::ParallelLoops), LinalgTransformationFilter(ArrayRef{}, StringAttr::get(ctx, "L1"))); - patterns.add>( - ctx, LinalgTilingOptions().setTileSizes(8000), + patterns.add( + DotOp::getOperationName(), ctx, LinalgTilingOptions().setTileSizes(8000), LinalgTransformationFilter( ArrayRef{StringAttr::get(ctx, "MEM"), StringAttr::get(ctx, "L3"), @@ -173,32 +177,34 @@ //===--------------------------------------------------------------------===// // Linalg tiling and permutation patterns. //===--------------------------------------------------------------------===// - patterns.add>( - ctx, + patterns.add( + MatmulOp::getOperationName(), ctx, LinalgTilingOptions() .setTileSizes({2000, 3000, 4000}) .setInterchange({1, 2, 0}), LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), StringAttr::get(ctx, "L2__with_perm__"))); - patterns.add>( - ctx, + patterns.add( + MatmulOp::getOperationName(), ctx, LinalgTilingOptions() .setTileSizes({200, 300, 400}) .setInterchange({1, 0, 2}), LinalgTransformationFilter(StringAttr::get(ctx, "L2__with_perm__"), StringAttr::get(ctx, "L1__with_perm__"))); - patterns.add>( - ctx, LinalgTilingOptions().setTileSizes({20, 30, 40}), + patterns.add( + MatmulOp::getOperationName(), ctx, + LinalgTilingOptions().setTileSizes({20, 30, 40}), LinalgTransformationFilter(StringAttr::get(ctx, "L1__with_perm__"), StringAttr::get(ctx, "REG__with_perm__"))); - patterns.add>( - ctx, LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), + patterns.add( + MatvecOp::getOperationName(), ctx, + LinalgTilingOptions().setTileSizes({5, 6}).setInterchange({1, 0}), LinalgTransformationFilter(StringAttr::get(ctx, "__with_perm__"), StringAttr::get(ctx, "L1__with_perm__"))); - patterns.add>( - ctx, + patterns.add( + MatmulOp::getOperationName(), ctx, LinalgTilingOptions() .setTileSizes({16, 8, 4}) .setInterchange({1, 2, 0}) @@ -274,8 +280,8 @@ SmallVectorImpl &patternsVector) { MLIRContext *ctx = funcOp.getContext(); patternsVector.emplace_back( - ctx, std::make_unique>( - ctx, + ctx, std::make_unique( + MatmulOp::getOperationName(), ctx, LinalgTilingOptions() .setTileSizes({8, 12, 16}) .setInterchange({1, 0, 2}), @@ -339,8 +345,9 @@ static void fillPromotionCallBackPatterns(MLIRContext *ctx, RewritePatternSet &patterns) { - patterns.add>( - ctx, LinalgTilingOptions().setTileSizes({16, 16, 16}), + patterns.add( + MatmulOp::getOperationName(), ctx, + LinalgTilingOptions().setTileSizes({16, 16, 16}), LinalgTransformationFilter(StringAttr::get(ctx, "START"), StringAttr::get(ctx, "PROMOTE"))); patterns.add>( @@ -382,8 +389,8 @@ 2, DistributionMethod::CyclicNumProcsEqNumIters); cyclicNprocsEqNiters.procInfo = getGpuProcIds; - patterns.add>( - context, + patterns.add( + MatmulOp::getOperationName(), context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) @@ -399,8 +406,8 @@ 2, DistributionMethod::CyclicNumProcsGeNumIters); cyclicNprocsGeNiters.procInfo = getGpuProcIds; - patterns.add>( - context, + patterns.add( + MatmulOp::getOperationName(), context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) @@ -416,8 +423,8 @@ DistributionMethod::Cyclic); cyclicNprocsDefault.procInfo = getGpuProcIds; - patterns.add>( - context, + patterns.add( + MatmulOp::getOperationName(), context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) @@ -433,8 +440,8 @@ DistributionMethod::CyclicNumProcsEqNumIters, DistributionMethod::CyclicNumProcsGeNumIters}; cyclicNprocsMixed1.procInfo = getGpuProcIds; - patterns.add>( - context, + patterns.add( + MatmulOp::getOperationName(), context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) @@ -450,8 +457,8 @@ DistributionMethod::CyclicNumProcsGeNumIters, DistributionMethod::Cyclic}; cyclicNprocsMixed2.procInfo = getGpuProcIds; - patterns.add>( - context, + patterns.add( + MatmulOp::getOperationName(), context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) @@ -468,8 +475,8 @@ DistributionMethod::CyclicNumProcsEqNumIters}; cyclicNprocsMixed3.procInfo = getGpuProcIds; - patterns.add>( - context, + patterns.add( + MatmulOp::getOperationName(), context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::ParallelLoops) @@ -485,8 +492,8 @@ DistributionMethod::Cyclic); cyclicNprocsEqNiters.procInfo = getGpuProcIds; - patterns.add>( - context, + patterns.add( + MatmulOp::getOperationName(), context, LinalgTilingOptions() .setTileSizes({8, 8, 4}) .setLoopType(LinalgTilingLoopType::Loops) @@ -507,8 +514,8 @@ fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns); } else if (testMatmulToVectorPatterns2dTiling) { stage1Patterns.emplace_back( - ctx, std::make_unique>( - ctx, + ctx, std::make_unique( + MatmulOp::getOperationName(), ctx, LinalgTilingOptions() .setTileSizes({768, 264, 768}) .setInterchange({1, 2, 0}), @@ -589,10 +596,9 @@ } else { linalgTilingOptions.setTileSizes(tileSizes); } - tilingPattern.add, - linalg::LinalgTilingPattern>( - context, linalgTilingOptions, - linalg::LinalgTransformationFilter(StringAttr::get(context, "tile"))); + linalg::LinalgTransformationFilter f(StringAttr::get(context, "tile")); + TilingPatterns::insert( + tilingPattern, linalgTilingOptions, f); (void)applyPatternsAndFoldGreedily(funcOp, std::move(tilingPattern)); }