diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -81,6 +81,12 @@
 //===----------------------------------------------------------------------===//
 /// Linalg strategy passes.
 //===----------------------------------------------------------------------===//
+/// Create a LinalgStrategyTileAndFusePass.
+std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyTileAndFusePass(
+    StringRef opName = "", linalg::LinalgTilingAndFusionOptions opt = {},
+    linalg::LinalgTransformationFilter filter =
+        linalg::LinalgTransformationFilter());
+
 /// Create a LinalgStrategyTilePass.
 std::unique_ptr<OperationPass<FuncOp>> createLinalgStrategyTilePass(
     StringRef opName = "",
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -235,6 +235,18 @@
   let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect"];
 }
 
+def LinalgStrategyTileAndFusePass
+    : FunctionPass<"linalg-strategy-tile-and-fuse-pass"> {
+  let summary = "Configurable pass to apply pattern-based tiling and fusion.";
+  let constructor = "mlir::createLinalgStrategyTileAndFusePass()";
+  let options = [
+    Option<"anchorFuncName", "anchor-func", "std::string", /*default=*/"",
+      "Which func op is the anchor to latch on.">,
+    Option<"anchorOpName", "anchor-op", "std::string", /*default=*/"",
+      "Which linalg op within the func is the anchor to latch on.">,
+  ];
+}
+
 def LinalgStrategyTilePass
     : FunctionPass<"linalg-strategy-tile-pass"> {
   let summary = "Configurable pass to apply pattern-based linalg tiling.";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/CodegenStrategy.h
@@ -30,6 +30,22 @@
   LinalgTransformationFilter::FilterFunction filter = nullptr;
 };
 
+/// Represent one application of LinalgStrategyTileAndFusePass.
+struct TileAndFuse : public Transformation {
+  TileAndFuse(StringRef name, linalg::LinalgTilingAndFusionOptions options,
+              LinalgTransformationFilter::FilterFunction f = nullptr)
+      : Transformation(f), opName(name), options(options) {}
+
+  void addToPassPipeline(OpPassManager &pm,
+                         LinalgTransformationFilter m) const override {
+    pm.addPass(createLinalgStrategyTileAndFusePass(opName, options, m));
+  }
+
+private:
+  std::string opName;
+  linalg::LinalgTilingAndFusionOptions options;
+};
+
 /// Represent one application of LinalgStrategyTilePass.
 struct Tile : public Transformation {
   Tile(StringRef name, linalg::LinalgTilingOptions options,
@@ -147,6 +163,22 @@
 
 /// Codegen strategy controls how a Linalg op is progressively lowered.
 struct CodegenStrategy {
+  /// Append a pattern to tile the Op `opName` and fuse its producers with
+  /// tiling and fusion `options`.
+  CodegenStrategy &
+  tileAndFuse(StringRef opName, LinalgTilingAndFusionOptions options,
+              LinalgTransformationFilter::FilterFunction f = nullptr) {
+    transformationSequence.emplace_back(
+        std::make_unique<TileAndFuse>(opName, options, f));
+    return *this;
+  }
+  /// Conditionally append a pattern to tile the Op `opName` and fuse its
+  /// producers with tiling and fusion `options`.
+  CodegenStrategy &
+  tileAndFuseIf(bool b, StringRef opName, LinalgTilingAndFusionOptions options,
+                LinalgTransformationFilter::FilterFunction f = nullptr) {
+    return b ? tileAndFuse(opName, options, f) : *this;
+  }
   /// Append a pattern to add a level of tiling for Op `opName` with tiling
   /// `options`.
   CodegenStrategy &
@@ -161,7 +193,7 @@
   CodegenStrategy &
   tileIf(bool b, StringRef opName, linalg::LinalgTilingOptions options,
          LinalgTransformationFilter::FilterFunction f = nullptr) {
-    return b ? tile(opName, options) : *this;
+    return b ? tile(opName, options, f) : *this;
   }
   /// Append a pattern to pad and hoist the operands of Op `opName` with padding
   /// `options`.
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -517,6 +517,16 @@
   }
 };
 
+/// Tiling and fusion options used by LinalgTileAndFuseTensorOpsPattern.
+struct LinalgTilingAndFusionOptions {
+  /// Tile sizes used to tile the root operation. Needs at least one entry per
+  /// root loop dimension; only the leading `numLoops` entries are used.
+  SmallVector<int64_t> tileSizes;
+  /// Tile interchange used to permute the tile loops relative to the root
+  /// operation loop dimensions. Either empty or as long as `tileSizes`.
+  SmallVector<int64_t> tileInterchange;
+};
+
 struct LinalgTilingOptions {
   /// Computation function that returns the tile sizes for each operation.
   /// Delayed construction of constant tile sizes should occur to interoperate
@@ -767,6 +777,34 @@
             fusionOptions, filter, fusedOpMarker, originalOpMarker, benefit) {}
 };
 
+///
+/// Linalg tile and fuse tensor ops pattern.
+///
+/// Apply tiling and fusion as a pattern.
+/// `filter` controls LinalgTransformMarker matching and update when specified.
+/// See `tileConsumerAndFuseProducers` for more details.
+struct LinalgTileAndFuseTensorOpsPattern : public RewritePattern {
+  // Entry point to match any LinalgOp.
+  LinalgTileAndFuseTensorOpsPattern(
+      MLIRContext *context, LinalgTilingAndFusionOptions options,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+  // Entry point to match a specific LinalgOp.
+  LinalgTileAndFuseTensorOpsPattern(
+      StringRef opName, MLIRContext *context,
+      LinalgTilingAndFusionOptions options,
+      LinalgTransformationFilter filter = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1);
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override;
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
+  /// Tile sizes and interchange used to tile the root operation.
+  LinalgTilingAndFusionOptions options;
+};
+
 ///
 /// Linalg generic interchage pattern.
 ///
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -235,6 +235,9 @@
   /// Returns the tiled root operation.
   LinalgOp getRootOp() { return rootOp; }
 
+  /// Returns the tiled root and fused producer ops in unspecified order.
+  SmallVector<LinalgOp> getAllTiledAndFusedOps();
+
 private:
   /// Returns true if the tile loop nest has no tile loops.
   bool isEmpty();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -390,6 +390,18 @@
   return tileLoopOps.front()->getOpResults();
 }
 
+SmallVector<LinalgOp> TileLoopNest::getAllTiledAndFusedOps() {
+  SmallVector<LinalgOp> result;
+  result.reserve(tiledRootAndFusedOpsLoops.size());
+  for (const auto &kvp : tiledRootAndFusedOpsLoops) {
+    auto linalgOp = dyn_cast<LinalgOp>(kvp.getFirst());
+    assert(linalgOp &&
+           "expect all tiled and fused operations are linalg operations");
+    result.push_back(linalgOp);
+  }
+  return result;
+}
+
 //===----------------------------------------------------------------------===//
 // Tile and fuse entry-points.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgStrategyPasses.cpp
@@ -36,6 +36,43 @@
 
 namespace {
 
+/// Configurable pass to apply pattern-based tiling and fusion.
+struct LinalgStrategyTileAndFusePass
+    : public LinalgStrategyTileAndFusePassBase<LinalgStrategyTileAndFusePass> {
+
+  LinalgStrategyTileAndFusePass() = default;
+
+  LinalgStrategyTileAndFusePass(StringRef opName,
+                                LinalgTilingAndFusionOptions opt,
+                                LinalgTransformationFilter filt)
+      : options(std::move(opt)), filter(std::move(filt)) {
+    this->anchorOpName.setValue(opName.str());
+  }
+
+  void runOnFunction() override {
+    auto funcOp = getFunction();
+    if (!anchorFuncName.empty() && funcOp.getName() != anchorFuncName)
+      return;
+
+    RewritePatternSet tilingAndFusionPattern(funcOp.getContext());
+    if (!anchorOpName.empty()) {
+      tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
+          anchorOpName, funcOp.getContext(), options, filter);
+    } else {
+      tilingAndFusionPattern.add<LinalgTileAndFuseTensorOpsPattern>(
+          funcOp.getContext(), options, filter);
+    }
+    // Traverse bottom-up so consumers are matched as roots before producers.
+    GreedyRewriteConfig grc;
+    grc.useTopDownTraversal = false;
+    (void)applyPatternsAndFoldGreedily(funcOp,
+                                       std::move(tilingAndFusionPattern), grc);
+  }
+
+  LinalgTilingAndFusionOptions options;
+  LinalgTransformationFilter filter;
+};
+
 /// Configurable pass to apply pattern-based linalg tiling.
 struct LinalgStrategyTilePass
     : public LinalgStrategyTilePassBase<LinalgStrategyTilePass> {
@@ -380,6 +417,15 @@
 };
 } // namespace
 
+/// Create a LinalgStrategyTileAndFusePass.
+std::unique_ptr<OperationPass<FuncOp>>
+mlir::createLinalgStrategyTileAndFusePass(StringRef opName,
+                                          LinalgTilingAndFusionOptions options,
+                                          LinalgTransformationFilter filter) {
+  return std::make_unique<LinalgStrategyTileAndFusePass>(
+      opName, std::move(options), std::move(filter));
+}
+
 /// Create a LinalgStrategyTilePass.
 std::unique_ptr<OperationPass<FuncOp>>
 mlir::createLinalgStrategyTilePass(StringRef opName, LinalgTilingOptions opt,
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -511,6 +511,75 @@
   return success();
 }
 
+/// Linalg tile and fuse tensor ops pattern.
+mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
+    LinalgTileAndFuseTensorOpsPattern(MLIRContext *context,
+                                      LinalgTilingAndFusionOptions options,
+                                      LinalgTransformationFilter filter,
+                                      PatternBenefit benefit)
+    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
+      options(options) {}
+
+mlir::linalg::LinalgTileAndFuseTensorOpsPattern::
+    LinalgTileAndFuseTensorOpsPattern(StringRef opName, MLIRContext *context,
+                                      LinalgTilingAndFusionOptions options,
+                                      LinalgTransformationFilter filter,
+                                      PatternBenefit benefit)
+    : RewritePattern(opName, benefit, context), filter(filter),
+      options(options) {}
+
+LogicalResult mlir::linalg::LinalgTileAndFuseTensorOpsPattern::matchAndRewrite(
+    Operation *op, PatternRewriter &rewriter) const {
+  LinalgOp rootOp = dyn_cast<LinalgOp>(op);
+  if (!rootOp)
+    return failure();
+  if (failed(filter.checkAndNotify(rewriter, op)))
+    return failure();
+
+  // Check `tileSizes` contains a tile size for every `rootOp` loop dimension.
+  if (options.tileSizes.size() < rootOp.getNumLoops())
+    return rewriter.notifyMatchFailure(op, "expect #tile sizes >= #loops");
+
+  // Check `tileInterchange` contains no entries or as many as `tileSizes`.
+  if (!options.tileInterchange.empty() &&
+      options.tileInterchange.size() != options.tileSizes.size())
+    return rewriter.notifyMatchFailure(
+        op, "expect the number of tile sizes and interchange dims to match");
+
+  // Copy the `tileSizes` and `tileInterchange` prefixes needed for `rootOp`.
+  SmallVector<int64_t> rootTileSizes(options.tileSizes.begin(),
+                                     options.tileSizes.begin() +
+                                         rootOp.getNumLoops());
+  SmallVector<int64_t> rootInterchange =
+      options.tileInterchange.empty()
+          ? llvm::to_vector<6>(llvm::seq<int64_t>(0, rootOp.getNumLoops()))
+          : SmallVector<int64_t>(options.tileInterchange.begin(),
+                                 options.tileInterchange.begin() +
+                                     rootOp.getNumLoops());
+
+  // Check `rootInterchange` is a permutation of the `rootOp` loop dimensions.
+  // It has to be a permutation since the tiling cannot tile the same loop
+  // dimension multiple times.
+  if (!isPermutation(rootInterchange))
+    return rewriter.notifyMatchFailure(
+        op, "expect the tile interchange permutes the root loops");
+
+  // Tile `rootOp` and fuse its producers.
+  FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
+      rewriter, rootOp, rootTileSizes, rootInterchange);
+  if (failed(tileLoopNest))
+    return rewriter.notifyMatchFailure(
+        op, "tileConsumerAndFuseProducers failed unexpectedly");
+
+  // Apply the filter if specified.
+  for (LinalgOp linalgOp : tileLoopNest->getAllTiledAndFusedOps())
+    filter.replaceLinalgTransformationFilter(rewriter, linalgOp);
+
+  // Replace `rootOp` via the rewriter and report success: the IR changed.
+  rewriter.replaceOp(op, tileLoopNest->getRootOpReplacementResults());
+  return success();
+}
+
 /// Linalg generic interchange pattern.
 mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
     MLIRContext *context, ArrayRef<unsigned> interchangeVector,
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -1,12 +1,13 @@
-// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=5,4,7 tile-interchange=1,0,2" -cse -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=5,4,7 tile-interchange=1,0,2 run-enable-pass=false" -cse -split-input-file | FileCheck --check-prefix=MATMUL %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.generic fuse tile-sizes=5,4,7 tile-interchange=1,0,2 run-enable-pass=false" -cse -split-input-file | FileCheck --check-prefix=GENERIC %s
 
-// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)>
-// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)>
-// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) ->
(d0, -d1 + 24)> -// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)> +// MATMUL-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)> +// MATMUL-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)> +// MATMUL-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 24)> +// MATMUL-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)> -// CHECK: fuse_input -// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> +// MATMUL: fuse_input +// MATMUL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> builtin.func @fuse_input(%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { @@ -18,31 +19,31 @@ %cst = arith.constant 0.000000e+00 : f32 %0 = linalg.fill(%cst, %arg0) : f32, tensor<24x12xf32> -> tensor<24x12xf32> - // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = - // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = - // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]]) - // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = - // CHECK: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]]) + // MATMUL: scf.for %[[IV0:[0-9a-zA-Z]*]] = + // MATMUL: scf.for %[[IV1:[0-9a-zA-Z]*]] = + // MATMUL: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]]) + // MATMUL: scf.for %[[IV2:[0-9a-zA-Z]*]] = + // MATMUL: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]]) // Tile both input operand dimensions. 
- // CHECK: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]]) - // CHECK: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]]) - // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]] - // CHECK-SAME: %[[IV1]], %[[IV2]] - // CHECK-SAME: %[[UB1]], %[[UB2]] - // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) - // CHECK: %{{.*}} = linalg.matmul ins(%[[T1]] + // MATMUL: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]]) + // MATMUL: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]]) + // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG0]] + // MATMUL-SAME: %[[IV1]], %[[IV2]] + // MATMUL-SAME: %[[UB1]], %[[UB2]] + // MATMUL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // MATMUL: %{{.*}} = linalg.matmul ins(%[[T1]] %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> return %1 : tensor<24x25xf32> } // ----- -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 25)> +// MATMUL-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (5, -d0 + 24)> +// MATMUL-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (4, -d0 + 25)> -// CHECK: fuse_output -// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32> +// MATMUL: fuse_output +// MATMUL-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32> builtin.func @fuse_output(%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { @@ -55,34 +56,34 @@ %0 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32> // Update the iteration argument of the outermost tile loop. 
- // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] - // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]] - // CHECK: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]]) - // CHECK: %[[TS0:.*]] = affine.min #[[MAP1]](%[[IV0]]) + // MATMUL: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] + // MATMUL: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]] + // MATMUL: %[[TS1:.*]] = affine.min #[[MAP0]](%[[IV1]]) + // MATMUL: %[[TS0:.*]] = affine.min #[[MAP1]](%[[IV0]]) // Tile the both output operand dimensions. - // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]] - // CHECK-SAME: %[[IV1]], %[[IV0]] - // CHECK-SAME: %[[TS1]], %[[TS0]] - // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) - // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]] - // CHECK: %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]] + // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG4]] + // MATMUL-SAME: %[[IV1]], %[[IV0]] + // MATMUL-SAME: %[[TS1]], %[[TS0]] + // MATMUL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // MATMUL: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]] + // MATMUL: %{{.*}} = linalg.matmul {{.*}} outs(%[[ARG5]] %1 = linalg.matmul ins(%arg0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%0 : tensor<24x25xf32>) -> tensor<24x25xf32> return %1 : tensor<24x25xf32> } // ----- -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 25)> -// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)> +// MATMUL-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (4, -d0 + 25)> +// MATMUL-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (7, -d0 + 12)> +// MATMUL-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 25)> +// MATMUL-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, -d1 + 12)> #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> #map1 = affine_map<(d0, d1, d2) 
-> (d0, d2)> -// CHECK: fuse_reduction -// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32> -// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x7x25xf32> +// MATMUL: fuse_reduction +// MATMUL-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xf32> +// MATMUL-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x7x25xf32> builtin.func @fuse_reduction(%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>, @@ -98,23 +99,23 @@ linalg.yield %2 : f32 } -> tensor<12x25xf32> - // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = - // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = - // CHECK: %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]]) - // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = - // CHECK: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]]) - // CHECK: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]]) - // CHECK: %[[UB0:.*]] = affine.min #[[MAP2]](%[[TS0]], %[[IV0]]) + // MATMUL: scf.for %[[IV0:[0-9a-zA-Z]*]] = + // MATMUL: scf.for %[[IV1:[0-9a-zA-Z]*]] = + // MATMUL: %[[TS0:.*]] = affine.min #[[MAP0]](%[[IV0]]) + // MATMUL: scf.for %[[IV2:[0-9a-zA-Z]*]] = + // MATMUL: %[[TS2:.*]] = affine.min #[[MAP1]](%[[IV2]]) + // MATMUL: %[[UB2:.*]] = affine.min #[[MAP3]](%[[TS2]], %[[IV2]]) + // MATMUL: %[[UB0:.*]] = affine.min #[[MAP2]](%[[TS0]], %[[IV0]]) // Tile only the parallel dimensions but not the reduction dimension. 
- // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]] - // CHECK-SAME: %[[IV2]], 0, %[[IV0]] - // CHECK-SAME: %[[UB2]], 7, %[[UB0]] - // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]] - // CHECK-SAME: %[[IV2]], %[[IV0]] - // CHECK-SAME: %[[UB2]], %[[UB0]] - // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]] - // CHECK: %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T2]] + // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG3]] + // MATMUL-SAME: %[[IV2]], 0, %[[IV0]] + // MATMUL-SAME: %[[UB2]], 7, %[[UB0]] + // MATMUL: %[[T1:.*]] = tensor.extract_slice %[[ARG1]] + // MATMUL-SAME: %[[IV2]], %[[IV0]] + // MATMUL-SAME: %[[UB2]], %[[UB0]] + // MATMUL: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]] + // MATMUL: %{{.*}} = linalg.matmul ins(%{{.*}}, %[[T2]] %1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> return %1 : tensor<24x25xf32> } @@ -124,9 +125,9 @@ #map0 = affine_map<(d0, d1) -> (d1, d0)> #map1 = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: fuse_transposed -// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> -// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x24xf32> +// MATMUL: fuse_transposed +// MATMUL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> +// MATMUL-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<12x24xf32> builtin.func @fuse_transposed(%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>, @@ -142,26 +143,26 @@ linalg.yield %2 : f32 } -> tensor<24x12xf32> - // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = - // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = - // CHECK: scf.for %[[IV2:[0-9a-zA-Z]*]] = + // MATMUL: scf.for %[[IV0:[0-9a-zA-Z]*]] = + // MATMUL: scf.for %[[IV1:[0-9a-zA-Z]*]] = + // MATMUL: scf.for %[[IV2:[0-9a-zA-Z]*]] = // Swap the input operand slice offsets due to the transposed indexing map. 
- // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG3]] - // CHECK-SAME: %[[IV2]], %[[IV1]] - // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] - // CHECK-SAME: %[[IV1]], %[[IV2]] - // CHECK: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]] - // CHECK: %{{.*}} = linalg.matmul ins(%[[T2]] + // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG3]] + // MATMUL-SAME: %[[IV2]], %[[IV1]] + // MATMUL: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] + // MATMUL-SAME: %[[IV1]], %[[IV2]] + // MATMUL: %[[T2:.*]] = linalg.generic {{.*}} ins(%[[T0]] {{.*}} outs(%[[T1]] + // MATMUL: %{{.*}} = linalg.matmul ins(%[[T2]] %1 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%arg2 : tensor<24x25xf32>) -> tensor<24x25xf32> return %1 : tensor<24x25xf32> } // ----- -// CHECK: fuse_input_and_output -// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> -// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32> +// MATMUL: fuse_input_and_output +// MATMUL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<24x12xf32> +// MATMUL-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<24x25xf32> builtin.func @fuse_input_and_output(%arg0: tensor<24x12xf32>, %arg1: tensor<12x25xf32>, %arg2: tensor<24x25xf32>) -> tensor<24x25xf32> { @@ -175,27 +176,27 @@ %1 = linalg.fill(%cst, %arg2) : f32, tensor<24x25xf32> -> tensor<24x25xf32> // Fuse both producers to the appropriate tile loops. 
- // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] - // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]] - // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG4]] - // CHECK-SAME: %[[IV1]], %[[IV0]] - // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) - // CHECK: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]] - // CHECK: %[[T2:.*]] = tensor.extract_slice %[[ARG0]] - // CHECK-SAME: %[[IV1]], %[[IV2]] - // CHECK: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]]) - // CHECK: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]] + // MATMUL: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] + // MATMUL: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]] + // MATMUL: %[[T0:.*]] = tensor.extract_slice %[[ARG4]] + // MATMUL-SAME: %[[IV1]], %[[IV0]] + // MATMUL: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // MATMUL: scf.for %[[IV2:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[T1]] + // MATMUL: %[[T2:.*]] = tensor.extract_slice %[[ARG0]] + // MATMUL-SAME: %[[IV1]], %[[IV2]] + // MATMUL: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]]) + // MATMUL: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[ARG5]] %2 = linalg.matmul ins(%0, %arg1 : tensor<24x12xf32>, tensor<12x25xf32>) outs(%1 : tensor<24x25xf32>) -> tensor<24x25xf32> return %2 : tensor<24x25xf32> } // ----- -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)> +// MATMUL-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)> #map0 = affine_map<(d0, d1) -> (d1, d0)> -// CHECK: fuse_indexed -// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xi32> +// MATMUL: fuse_indexed +// MATMUL-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<12x25xi32> builtin.func @fuse_indexed(%arg0: tensor<24x12xi32>, %arg1: tensor<12x25xi32>, %arg2: tensor<24x25xi32>) -> tensor<24x25xi32> { @@ -213,19 +214,19 @@ linalg.yield %9 : i32 } -> tensor<12x25xi32> - // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = - // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = - // CHECK: scf.for 
%[[IV2:[0-9a-zA-Z]*]] = + // MATMUL: scf.for %[[IV0:[0-9a-zA-Z]*]] = + // MATMUL: scf.for %[[IV1:[0-9a-zA-Z]*]] = + // MATMUL: scf.for %[[IV2:[0-9a-zA-Z]*]] = // Shift the indexes by the slice offsets and swap the offsets due to the transposed indexing map. - // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]] - // CHECK-SAME: %[[IV2]], %[[IV0]] - // CHECK: linalg.generic {{.*}} outs(%[[T1]] - // CHECK: %[[IDX0:.*]] = linalg.index 0 - // CHECK: %[[IDX0_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX0]], %[[IV0]]) - // CHECK: %[[IDX1:.*]] = linalg.index 1 - // CHECK: %[[IDX1_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX1]], %[[IV2]]) - // CHECK: %{{.*}} = arith.addi %[[IDX0_SHIFTED]], %[[IDX1_SHIFTED]] + // MATMUL: %[[T1:.*]] = tensor.extract_slice %[[ARG1]] + // MATMUL-SAME: %[[IV2]], %[[IV0]] + // MATMUL: linalg.generic {{.*}} outs(%[[T1]] + // MATMUL: %[[IDX0:.*]] = linalg.index 0 + // MATMUL: %[[IDX0_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX0]], %[[IV0]]) + // MATMUL: %[[IDX1:.*]] = linalg.index 1 + // MATMUL: %[[IDX1_SHIFTED:.*]] = affine.apply #[[MAP0]](%[[IDX1]], %[[IV2]]) + // MATMUL: %{{.*}} = arith.addi %[[IDX0_SHIFTED]], %[[IDX1_SHIFTED]] %1 = linalg.matmul ins(%arg0, %0 : tensor<24x12xi32>, tensor<12x25xi32>) outs(%arg2 : tensor<24x25xi32>) -> tensor<24x25xi32> return %1 : tensor<24x25xi32> } @@ -235,28 +236,28 @@ #map0 = affine_map<(d0, d1) -> (d0, d1)> #map1 = affine_map<(d0, d1) -> (d0)> -// CHECK: fuse_outermost_reduction -// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32> -// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<10xf32> +// GENERIC: fuse_outermost_reduction +// GENERIC-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32> +// GENERIC-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<10xf32> func @fuse_outermost_reduction(%arg0: tensor<10x17xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> { %cst = arith.constant 0.000000e+00 : f32 %0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32> // Cannot fuse the output fill since the reduction 
loop is the outermost loop. - // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG1]]) + // GENERIC: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG1]]) %1 = linalg.fill(%cst, %arg1) : f32, tensor<10xf32> -> tensor<10xf32> - // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]] - // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] + // GENERIC: scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]] + // GENERIC: scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] - // Check the input fill has been fused. - // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] - // CHECK-SAME: %[[IV1]], %[[IV0]] - // CHECK: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]]) - // CHECK: %[[T3:.*]] = tensor.extract_slice %[[ARG3]] - // CHECK-SAME: %[[IV1]] - // CHECK: linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]] + // Check the input fill has been fused. + // GENERIC: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] + // GENERIC-SAME: %[[IV1]], %[[IV0]] + // GENERIC: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]]) + // GENERIC: %[[T3:.*]] = tensor.extract_slice %[[ARG3]] + // GENERIC-SAME: %[[IV1]] + // GENERIC: linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]] %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<10x17xf32>) outs(%1 : tensor<10xf32>) { ^bb0(%arg2: f32, %arg3: f32): // no predecessors %3 = arith.addf %arg2, %arg3 : f32 @@ -267,39 +268,39 @@ // ----- -// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)> -// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (8, -d0 - d1 + 17)> -// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, -d1 - d2 + 17)> +// GENERIC-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)> +// GENERIC-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (8, -d0 - d1 + 17)> +// GENERIC-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, -d1 - d2 + 17)> #map0 = affine_map<(d0, d1) -> (d0, d0 + d1)>
#map1 = affine_map<(d0, d1) -> (d0, d1)> -// CHECK: fuse_non_rectangular -// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32> +// GENERIC: fuse_non_rectangular +// GENERIC-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32> func @fuse_non_rectangular(%arg0: tensor<10x17xf32>, %arg1: tensor<10x8xf32>) -> tensor<10x8xf32> { - // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index - // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index - // CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index - // CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index - // CHECK-DAG: %[[C10:.*]] = arith.constant 10 : index + // GENERIC-DAG: %[[C0:.*]] = arith.constant 0 : index + // GENERIC-DAG: %[[C4:.*]] = arith.constant 4 : index + // GENERIC-DAG: %[[C5:.*]] = arith.constant 5 : index + // GENERIC-DAG: %[[C8:.*]] = arith.constant 8 : index + // GENERIC-DAG: %[[C10:.*]] = arith.constant 10 : index %cst = arith.constant 0.000000e+00 : f32 %0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32> - // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = %[[C0]] to %[[C8]] step %[[C4]] - // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = %[[C0]] to %[[C10]] step %[[C5]] + // GENERIC: scf.for %[[IV0:[0-9a-zA-Z]*]] = %[[C0]] to %[[C8]] step %[[C4]] + // GENERIC: scf.for %[[IV1:[0-9a-zA-Z]*]] = %[[C0]] to %[[C10]] step %[[C5]] // Compute producer on a hyper rectangular bounding box. Along the second dimenson, // the offset is set to the sum of the induction variables, and the upper bound // to either 8 (tile size) or 17 (sum of max indices (9+7) then + 1) minus the // induction variables. 
- // CHECK: %[[SUM:.*]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV0]] - // CHECK: %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[IV0]] - // CHECK: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]], %[[IV0]] - // CHECK: %[[T0:.*]] = tensor.extract_slice %[[ARG0]] - // CHECK-SAME: %[[IV1]], %[[SUM]] - // CHECK-SAME: , %[[UB1]] - // CHECK: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) + // GENERIC-DAG: %[[SUM:.*]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV0]] + // GENERIC-DAG: %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[IV0]] + // GENERIC-DAG: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]], %[[IV0]] + // GENERIC: %[[T0:.*]] = tensor.extract_slice %[[ARG0]] + // GENERIC-SAME: %[[IV1]], %[[SUM]] + // GENERIC-SAME: , %[[UB1]] + // GENERIC: %[[T1:.*]] = linalg.fill(%{{.*}}, %[[T0]]) %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x17xf32>) outs(%arg1 : tensor<10x8xf32>) { ^bb0(%arg2: f32, %arg3: f32): // no predecessors %2 = arith.addf %arg2, %arg3 : f32 diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-sequence-on-tensors.mlir @@ -1,11 +1,12 @@ -// RUN: mlir-opt %s -linalg-tile-and-fuse-tensor-ops="tile-sizes=4,4,0,0 tile-interchange=0,1,2,3" -cse --canonicalize -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.conv_2d fuse tile-sizes=4,4,0,0 tile-interchange=0,1,2,3 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=CONV %s +// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=4,4,0 tile-interchange=0,1,2 run-enable-pass=false" -split-input-file | FileCheck --check-prefix=MATMUL %s -// CHECK: fuse_conv_chain -// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<2x2xf32> -// CHECK-SAME: 
%[[ARG1:[0-9a-zA-Z]*]]: tensor<11x11xf32> -// CHECK-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<10x10xf32> -// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<9x9xf32> -// CHECK-SAME: %[[ARG4:[0-9a-zA-Z]*]]: tensor<8x8xf32> +// CONV: fuse_conv_chain +// CONV-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<2x2xf32> +// CONV-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<11x11xf32> +// CONV-SAME: %[[ARG2:[0-9a-zA-Z]*]]: tensor<10x10xf32> +// CONV-SAME: %[[ARG3:[0-9a-zA-Z]*]]: tensor<9x9xf32> +// CONV-SAME: %[[ARG4:[0-9a-zA-Z]*]]: tensor<8x8xf32> builtin.func @fuse_conv_chain(%arg0: tensor<2x2xf32>, %arg1: tensor<11x11xf32>, %arg2: tensor<10x10xf32>, @@ -14,34 +15,34 @@ %cst = arith.constant 1.0 : f32 // Do not tile the filter fill since the filter dimensions are not tiled. - // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]]) + // CONV: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]]) %0 = linalg.fill(%cst, %arg0) : f32, tensor<2x2xf32> -> tensor<2x2xf32> // Fuse all other operations. - // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[ARG4]] - // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG6:.*]] = %[[ARG5]] + // CONV: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG5:.*]] = %[[ARG4]] + // CONV: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG6:.*]] = %[[ARG5]] - // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG1]] - // CHECK-SAME: %[[IV0]], %[[IV1]] - // CHECK: %[[T2:.*]] = tensor.extract_slice %[[ARG2]] - // CHECK-SAME: %[[IV0]], %[[IV1]] - // CHECK: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]]) - // CHECK: %[[T4:.*]] = linalg.conv_2d ins(%[[T1]], %[[T0]] : {{.*}} outs(%[[T3]] + // CONV: %[[T1:.*]] = tensor.extract_slice %[[ARG1]] + // CONV-SAME: %[[IV0]], %[[IV1]] + // CONV: %[[T2:.*]] = tensor.extract_slice %[[ARG2]] + // CONV-SAME: %[[IV0]], %[[IV1]] + // CONV: %[[T3:.*]] = linalg.fill(%{{.*}}, %[[T2]]) + // CONV: %[[T4:.*]] = linalg.conv_2d ins(%[[T1]], %[[T0]] : {{.*}} outs(%[[T3]] %1 = linalg.fill(%cst, %arg2) : f32, tensor<10x10xf32> -> tensor<10x10xf32> %2 = linalg.conv_2d 
ins(%arg1, %0 : tensor<11x11xf32>, tensor<2x2xf32>) outs(%1 : tensor<10x10xf32>) -> tensor<10x10xf32> - // CHECK: %[[T5:.*]] = tensor.extract_slice %[[ARG3]] - // CHECK-SAME: %[[IV0]], %[[IV1]] - // CHECK: %[[T6:.*]] = linalg.fill(%{{.*}}, %[[T5]]) - // CHECK: %[[T7:.*]] = linalg.conv_2d ins(%[[T4]], %[[T0]] : {{.*}} outs(%[[T6]] + // CONV: %[[T5:.*]] = tensor.extract_slice %[[ARG3]] + // CONV-SAME: %[[IV0]], %[[IV1]] + // CONV: %[[T6:.*]] = linalg.fill(%{{.*}}, %[[T5]]) + // CONV: %[[T7:.*]] = linalg.conv_2d ins(%[[T4]], %[[T0]] : {{.*}} outs(%[[T6]] %3 = linalg.fill(%cst, %arg3) : f32, tensor<9x9xf32> -> tensor<9x9xf32> %4 = linalg.conv_2d ins(%2, %0 : tensor<10x10xf32>, tensor<2x2xf32>) outs(%3 : tensor<9x9xf32>) -> tensor<9x9xf32> // Use the argument passed in by iteration argument. - // CHECK: %[[T8:.*]] = tensor.extract_slice %[[ARG6]] - // CHECK-SAME: %[[IV0]], %[[IV1]] - // CHECK: %[[T9:.*]] = linalg.fill(%{{.*}}, %[[T8]]) - // CHECK: %[[T5:.*]] = linalg.conv_2d ins(%[[T7]], %[[T0]] {{.*}} outs(%[[T9]] + // CONV: %[[T8:.*]] = tensor.extract_slice %[[ARG6]] + // CONV-SAME: %[[IV0]], %[[IV1]] + // CONV: %[[T9:.*]] = linalg.fill(%{{.*}}, %[[T8]]) + // CONV: %[[T5:.*]] = linalg.conv_2d ins(%[[T7]], %[[T0]] {{.*}} outs(%[[T9]] %5 = linalg.fill(%cst, %arg4) : f32, tensor<8x8xf32> -> tensor<8x8xf32> %6 = linalg.conv_2d ins(%4, %0 : tensor<9x9xf32>, tensor<2x2xf32>) outs(%5 : tensor<8x8xf32>) -> tensor<8x8xf32> return %6 : tensor<8x8xf32> @@ -49,8 +50,8 @@ // ----- -// CHECK: fuse_matmul_chain -// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<8x8xf32> +// MATMUL: fuse_matmul_chain +// MATMUL-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<8x8xf32> builtin.func @fuse_matmul_chain(%arg0: tensor<8x8xf32>) -> tensor<8x8xf32> { %c0 = arith.constant 0 : index %c12 = arith.constant 12 : index @@ -60,24 +61,24 @@ %cst = arith.constant 0.000000e+00 : f32 // Do not tile rhs fill of the producer matmul since none of its loop dimension is tiled. 
- // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]]) + // MATMUL: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG0]]) %0 = linalg.fill(%cst, %arg0) : f32, tensor<8x8xf32> -> tensor<8x8xf32> - // CHECK: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG1:.*]] = %[[ARG0]] - // CHECK: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[ARG1]] + // MATMUL: scf.for %[[IV0:.*]] = {{.*}} iter_args(%[[ARG1:.*]] = %[[ARG0]] + // MATMUL: scf.for %[[IV1:.*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[ARG1]] // Only the outermost loop of the producer matmul is tiled. - // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] - // CHECK-SAME: %[[IV0]], 0 - // CHECK: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]]) - // CHECK: %[[T3:.*]] = linalg.matmul ins(%[[T2]], %[[T0]] {{.*}} + // MATMUL: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] + // MATMUL-SAME: %[[IV0]], 0 + // MATMUL: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]]) + // MATMUL: %[[T3:.*]] = linalg.matmul ins(%[[T2]], %[[T0]] {{.*}} %1 = linalg.matmul ins(%0, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32> // Use the argument passed in by iteration argument. 
- // CHECK: %[[T4:.*]] = tensor.extract_slice %[[ARG2]] - // CHECK-SAME: %[[IV0]], %[[IV1]] - // CHECK: %[[T5:.*]] = linalg.fill(%{{.*}}, %[[T4]]) - // CHECK: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T5]] + // MATMUL: %[[T4:.*]] = tensor.extract_slice %[[ARG2]] + // MATMUL-SAME: %[[IV0]], %[[IV1]] + // MATMUL: %[[T5:.*]] = linalg.fill(%{{.*}}, %[[T4]]) + // MATMUL: %{{.*}} = linalg.matmul ins(%[[T3]], {{.*}} outs(%[[T5]] %2 = linalg.matmul ins(%1, %0 : tensor<8x8xf32>, tensor<8x8xf32>) outs(%0 : tensor<8x8xf32>) -> tensor<8x8xf32> return %2 : tensor<8x8xf32> } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgCodegenStrategy.cpp @@ -52,16 +52,21 @@ void runOnFunction() override; - void runStrategy(LinalgTilingOptions tilingOptions, + void runStrategy(LinalgTilingAndFusionOptions tilingAndFusionOptions, + LinalgTilingOptions tilingOptions, LinalgTilingOptions registerTilingOptions, LinalgPaddingOptions paddingOptions, vector::VectorContractLowering vectorContractLowering, vector::VectorTransferSplit vectorTransferSplit); + Option fuse{ + *this, "fuse", + llvm::cl::desc("Fuse the producers after tiling the root op."), + llvm::cl::init(false)}; ListOption tileSizes{*this, "tile-sizes", llvm::cl::MiscFlags::CommaSeparated, llvm::cl::desc("Specifies the tile sizes.")}; - ListOption tileInterchange{ + ListOption tileInterchange{ *this, "tile-interchange", llvm::cl::MiscFlags::CommaSeparated, llvm::cl::desc("Specifies the tile interchange.")}; @@ -148,6 +153,7 @@ }; void TestLinalgCodegenStrategy::runStrategy( + LinalgTilingAndFusionOptions tilingAndFusionOptions, LinalgTilingOptions tilingOptions, LinalgTilingOptions registerTilingOptions, LinalgPaddingOptions paddingOptions, @@ -156,7 +162,10 @@ assert(!anchorOpName.empty()); CodegenStrategy strategy; StringRef 
genericOpName = GenericOp::getOperationName(); - strategy.tileIf(!tileSizes.empty(), anchorOpName, tilingOptions) + strategy + .tileAndFuseIf(fuse && !tileSizes.empty(), anchorOpName, + tilingAndFusionOptions) + .tileIf(!fuse && !tileSizes.empty(), anchorOpName, tilingOptions) .promoteIf(promote, anchorOpName, LinalgPromotionOptions() .setAlignment(16) @@ -204,11 +213,17 @@ if (!anchorFuncOpName.empty() && anchorFuncOpName != getFunction().getName()) return; + LinalgTilingAndFusionOptions tilingAndFusionOptions; + tilingAndFusionOptions.tileSizes = {tileSizes.begin(), tileSizes.end()}; + tilingAndFusionOptions.tileInterchange = {tileInterchange.begin(), + tileInterchange.end()}; + LinalgTilingOptions tilingOptions; if (!tileSizes.empty()) tilingOptions = tilingOptions.setTileSizes(tileSizes); if (!tileInterchange.empty()) - tilingOptions = tilingOptions.setInterchange(tileInterchange); + tilingOptions = tilingOptions.setInterchange( + SmallVector(tileInterchange.begin(), tileInterchange.end())); LinalgTilingOptions registerTilingOptions; if (!registerTileSizes.empty()) @@ -245,8 +260,8 @@ .Case("vector-transfers", vector::VectorTransferSplit::VectorTransfer) .Default(vector::VectorTransferSplit::None); - runStrategy(tilingOptions, registerTilingOptions, paddingOptions, - vectorContractLowering, vectorTransferSplit); + runStrategy(tilingAndFusionOptions, tilingOptions, registerTilingOptions, + paddingOptions, vectorContractLowering, vectorTransferSplit); } namespace mlir {