diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -635,8 +635,20 @@
 struct LinalgTilingAndFusionOptions {
   /// Tile sizes used to tile the root operation.
   SmallVector<int64_t> tileSizes;
+  LinalgTilingAndFusionOptions &setTileSizes(ArrayRef<int64_t> ts) {
+    tileSizes.assign(ts.begin(), ts.end());
+    return *this;
+  }
   /// Tile interchange used to permute the tile loops.
   SmallVector<int64_t> tileInterchange;
+  /// When specified, specifies distribution of generated tile loops to
+  /// processors.
+  Optional<LinalgLoopDistributionOptions> tileDistribution = None;
+  LinalgTilingAndFusionOptions &
+  setDistributionOptions(LinalgLoopDistributionOptions distributionOptions) {
+    tileDistribution = std::move(distributionOptions);
+    return *this;
+  }
 };
 
 struct LinalgTilingOptions {
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -245,73 +245,6 @@
                       OpResult producerOpResult,
                       OpOperand &consumerOpOperand);
 
-//===----------------------------------------------------------------------===//
-// Fusion on tensor utilities
-//===----------------------------------------------------------------------===//
-
-/// A struct to manage the tile loop nest specific information.
-class TileLoopNest {
-public:
-  TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
-
-  /// Tile the root operation using the given `tileSizes` and `tileInterchange`.
-  LogicalResult tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
-                           ArrayRef<int64_t> tileInterchange);
-
-  /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
-  /// the fused producer or fails if fusion is not possible.
-  FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
-
-  /// Returns the replacement results for the original untiled root operation.
-  ValueRange getRootOpReplacementResults();
-
-  /// Returns the tiled root operation.
-  LinalgOp getRootOp() { return rootOp; }
-
-  /// Returns the tiled root operation and the fused producers.
-  SmallVector<LinalgOp> getAllTiledAndFusedOps();
-
-  /// Returns the loop ops generated from tiling.
-  ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
-
-  /// Returns true if the tile loop nest has no tile loops.
-  bool isEmpty();
-
-private:
-  /// Returns true if the tile loop nest invariants are satisfied:
-  /// - The `rootOp` has been tiled at least once.
-  /// - The number of tile loop operations and dimensions match.
-  /// - The innermost tile loop is the parent of `tiledOp`.
-  /// - The tile loops are directly nested.
-  // TODO: relax to support additional control flow, e.g., IfOp.
-  bool isValid();
-
-  /// Searches the block arguments tied to a block argument `bbArg` of the
-  /// innermost tile loop. Returns the block argument from outermost to
-  /// innermost or an empty vector if none are found.
-  SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
-
-  /// Returns the iteration argument of the outermost tile loop mapped to a
-  /// block argument `bbArg` of the innermost tile loop.
-  OpOperand *getTiedIterArg(BlockArgument bbArg);
-
-  /// Returns true if `bbArg` has other used than `sliceOp` and its
-  /// dependencies. Only if there are no other uses, the producer output
-  /// iteration argument may reused to pass the producer result after fusion.
-  bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
-
-  LinalgOp rootOp;
-  SmallVector<scf::ForOp> tileLoopOps;
-  DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
-};
-
-/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
-/// `tileSizes` and `tileInterchange` parameters to control the tiling.
-FailureOr<TileLoopNest>
-tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
-                             ArrayRef<int64_t> tileSizes,
-                             ArrayRef<int64_t> tileInterchange);
-
 //===----------------------------------------------------------------------===//
 // Distribution utilities
 //===----------------------------------------------------------------------===//
@@ -396,6 +329,77 @@
                                    Value procId, Value nprocs, Value &lb,
                                    Value &ub, Value &step);
 
+//===----------------------------------------------------------------------===//
+// Fusion on tensor utilities
+//===----------------------------------------------------------------------===//
+
+/// A struct to manage the tile loop nest specific information.
+class TileLoopNest {
+public:
+  TileLoopNest(LinalgOp rootOp) : rootOp(rootOp) {}
+
+  /// Tile the root operation using the given `tileSizes`, `tileInterchange`,
+  /// and `tileDistribution`.
+  LogicalResult
+  tileRootOp(OpBuilder &b, ArrayRef<int64_t> tileSizes,
+             ArrayRef<int64_t> tileInterchange,
+             Optional<LinalgLoopDistributionOptions> tileDistribution);
+
+  /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns
+  /// the fused producer or fails if fusion is not possible.
+  FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand);
+
+  /// Returns the replacement results for the original untiled root operation.
+  ValueRange getRootOpReplacementResults();
+
+  /// Returns the tiled root operation.
+  LinalgOp getRootOp() { return rootOp; }
+
+  /// Returns the tiled root operation and the fused producers.
+  SmallVector<LinalgOp> getAllTiledAndFusedOps();
+
+  /// Returns the loop ops generated from tiling.
+  ArrayRef<scf::ForOp> getLoopOps() { return tileLoopOps; }
+
+  /// Returns true if the tile loop nest has no tile loops.
+  bool isEmpty();
+
+private:
+  /// Returns true if the tile loop nest invariants are satisfied:
+  /// - The `rootOp` has been tiled at least once.
+  /// - The number of tile loop operations and dimensions match.
+  /// - The innermost tile loop is the parent of `tiledOp`.
+  /// - The tile loops are directly nested.
+  // TODO: relax to support additional control flow, e.g., IfOp.
+  bool isValid();
+
+  /// Searches the block arguments tied to a block argument `bbArg` of the
+  /// innermost tile loop. Returns the block argument from outermost to
+  /// innermost or an empty vector if none are found.
+  SmallVector<BlockArgument> getTiedBBArgs(BlockArgument bbArg);
+
+  /// Returns the iteration argument of the outermost tile loop mapped to a
+  /// block argument `bbArg` of the innermost tile loop.
+  OpOperand *getTiedIterArg(BlockArgument bbArg);
+
+  /// Returns true if `bbArg` has other uses than `sliceOp` and its
+  /// dependencies. Only if there are no other uses, the producer output
+  /// iteration argument may be reused to pass the producer result after fusion.
+  bool hasOtherUses(BlockArgument bbArg, tensor::ExtractSliceOp sliceOp);
+
+  LinalgOp rootOp;
+  SmallVector<scf::ForOp> tileLoopOps;
+  DenseMap<Operation *, SmallVector<int64_t>> tiledRootAndFusedOpsLoops;
+};
+
+/// Tiles `consumerOp` and fuses its dependencies if possible. Uses the
+/// `tileSizes`, `tileInterchange`, and `tileDistribution` parameters to control
+/// the tiling.
+FailureOr<TileLoopNest> tileConsumerAndFuseProducers(
+    OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
+    ArrayRef<int64_t> tileInterchange,
+    Optional<LinalgLoopDistributionOptions> tileDistribution);
+
 //===----------------------------------------------------------------------===//
 // Generic op region utilities
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -269,9 +269,10 @@
   });
 }
 
-LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
-                                       ArrayRef<int64_t> tileSizes,
-                                       ArrayRef<int64_t> tileInterchange) {
+LogicalResult TileLoopNest::tileRootOp(
+    OpBuilder &b, ArrayRef<int64_t> tileSizes,
+    ArrayRef<int64_t> tileInterchange,
+    Optional<LinalgLoopDistributionOptions> tileDistribution) {
   // Exit if all tile sizes are zero.
   if (tileSizes.size() == static_cast<size_t>(count(tileSizes, 0)))
     return success();
@@ -283,6 +284,9 @@
               tileInterchange.begin(), tileInterchange.end()))
           .setTileSizes(tileSizes)
           .setLoopType(LinalgTilingLoopType::Loops);
+  if (tileDistribution)
+    tilingOptions =
+        tilingOptions.setDistributionOptions(tileDistribution.getValue());
 
   // TODO: Propagate RewriterBase everywhere.
   IRRewriter rewriter(b);
@@ -408,10 +412,10 @@
 // Tile and fuse entry-points.
 //===----------------------------------------------------------------------===//
 
-FailureOr<TileLoopNest>
-mlir::linalg::tileConsumerAndFuseProducers(OpBuilder &b, LinalgOp consumerOp,
-                                           ArrayRef<int64_t> tileSizes,
-                                           ArrayRef<int64_t> tileInterchange) {
+FailureOr<TileLoopNest> mlir::linalg::tileConsumerAndFuseProducers(
+    OpBuilder &b, LinalgOp consumerOp, ArrayRef<int64_t> tileSizes,
+    ArrayRef<int64_t> tileInterchange,
+    Optional<LinalgLoopDistributionOptions> tileDistribution) {
   assert(tileSizes.size() == tileInterchange.size() &&
          "expect the number of tile sizes and interchange dims to match");
   assert(isPermutation(tileInterchange) &&
@@ -446,7 +450,8 @@
   SmallVector<int64_t> outerTileSizes;
   outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
   outerTileSizes.append(tileSizes.size() - split, 0);
-  if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange)))
+  if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
+                                     tileDistribution)))
     return failure();
   fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
 
@@ -454,7 +459,8 @@
   SmallVector<int64_t> innerTileSizes;
   innerTileSizes.append(split, 0);
   innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
-  if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange)))
+  if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
+                                     tileDistribution)))
     return failure();
   fuseProducersGreedily(tileLoopNest.getRootOp().getInputOperands());
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -613,8 +613,9 @@
         op, "expect the tile interchange permutes the root loops");
 
   // Tile `rootOp` and fuse its producers.
-  FailureOr<TileLoopNest> tileLoopNest = tileConsumerAndFuseProducers(
-      rewriter, rootOp, rootTileSizes, rootInterchange);
+  FailureOr<TileLoopNest> tileLoopNest =
+      tileConsumerAndFuseProducers(rewriter, rootOp, rootTileSizes,
+                                   rootInterchange, options.tileDistribution);
   if (failed(tileLoopNest))
     return rewriter.notifyMatchFailure(
         op, "tileConsumerAndFuseProducers failed unexpectedly");
diff --git a/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-fuse-and-distribute.mlir
@@ -0,0 +1,53 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-fuse-and-distribute-options -split-input-file | FileCheck %s
+
+// CHECK: #[[MULMAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
+// CHECK: #[[ADDMAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func @fill_matmul_tensors(
+// CHECK-SAME:    %[[TA:[0-9a-z]+]]: tensor<?x?xf32>
+// CHECK-SAME:    %[[TB:[0-9a-z]+]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+func @fill_matmul_tensors(
+  %arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>)
+    -> tensor<?x?xf32> {
+// CHECK-DAG: %[[C8:.*]] = arith.constant 8 : index
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[BIDY:.*]] = gpu.block_id y
+// CHECK-DAG: %[[NBLOCKSY:.*]] = gpu.grid_dim y
+// CHECK-DAG: %[[BIDX:.*]] = gpu.block_id x
+// CHECK-DAG: %[[NBLOCKSX:.*]] = gpu.grid_dim x
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDY]], %[[C8]]]
+// CHECK: %[[LBY:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
+// CHECK: %[[STEPY:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSY]], %[[C8]]]
+// CHECK: %[[TD0:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC0:.*]] = %[[INIT]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[MUL:.+]] = affine.apply #[[MULMAP]]()[%[[BIDX]], %[[C8]]]
+// CHECK: %[[LBX:.+]] = affine.apply #[[ADDMAP]]()[%[[MUL]], %[[C0]]]
+// CHECK: %[[STEPX:.+]] = affine.apply #[[MULMAP]]()[%[[NBLOCKSX]], %[[C8]]]
+// CHECK: %[[TD1:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC1:.*]] = %[[TC0]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[TC1]]
+// CHECK: %[[FILL:.+]] = linalg.fill(%{{.+}}, %[[SLICE]])
+// CHECK: %[[sTD2:.*]] = scf.for {{.*}} to {{.*}} step {{.*}} iter_args(%[[TC2:.*]] = %[[FILL]]) -> (tensor<?x?xf32>) {
+// CHECK: %[[sTA:.*]] = tensor.extract_slice %[[TA]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTB:.*]] = tensor.extract_slice %[[TB]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTC:.*]] = tensor.extract_slice %[[TC2]][{{.*}}] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[sTD:.*]] = linalg.matmul ins(%[[sTA]], %[[sTB]] : tensor<?x?xf32>, tensor<?x?xf32>)
+// CHECK-SAME: outs(%[[sTC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[TD:.*]] = tensor.insert_slice %[[sTD]] into %[[TC2]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: scf.yield %[[TD]] : tensor<?x?xf32>
+// CHECK: %[[TD2:.*]] = tensor.insert_slice %[[sTD2]] into %[[TC1]][{{.*}}] : tensor<?x?xf32> into tensor<?x?xf32>
+// CHECK: scf.yield %[[TD2]] : tensor<?x?xf32>
+// CHECK: scf.yield %[[TD1]] : tensor<?x?xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %cst = arith.constant 0.0 : f32
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.fill(%cst, %2) : f32, tensor<?x?xf32> -> tensor<?x?xf32>
+  %4 = linalg.matmul {__internal_linalg_transform__ = "tensors_fuse_distribute1"}
+      ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x?xf32>)
+      outs(%3: tensor<?x?xf32>)
+    -> tensor<?x?xf32>
+
+// CHECK: return %[[TD0]] : tensor<?x?xf32>
+  return %4 : tensor<?x?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -78,6 +78,10 @@
       *this, "test-tile-and-distribute-options",
       llvm::cl::desc("Test tile and distribute options"),
      llvm::cl::init(false)};
+  Option<bool> testTileFuseAndDistributionOptions{
+      *this, "test-tile-fuse-and-distribute-options",
+      llvm::cl::desc("Test tile, fuse and distribute options"),
+      llvm::cl::init(false)};
   Option<bool> testVectorTransferForwardingPatterns{
       *this, "test-vector-transfer-forwarding-patterns",
       llvm::cl::desc(
@@ -505,6 +509,21 @@
   }
 }
 
+static void fillTileFuseAndDistributePatterns(MLIRContext *context,
+                                              RewritePatternSet &patterns) {
+  LinalgLoopDistributionOptions cyclicNprocsEqNiters;
+  cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic);
+  cyclicNprocsEqNiters.procInfo = getGpuProcIds<gpu::BlockIdOp, gpu::GridDimOp>;
+  patterns.add<LinalgTileAndFuseTensorOpsPattern>(
+      MatmulOp::getOperationName(), context,
+      LinalgTilingAndFusionOptions()
+          .setTileSizes({8, 8, 4})
+          .setDistributionOptions(cyclicNprocsEqNiters),
+      LinalgTransformationFilter(
+          StringAttr::get(context, "tensors_fuse_distribute1"),
+          StringAttr::get(context, "tensors_after_fuse_distribute1")));
+}
+
 static void
 applyMatmulToVectorPatterns(FuncOp funcOp,
                             bool testMatmulToVectorPatterns1dTiling,
@@ -698,6 +717,12 @@
     (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
     return;
   }
+  if (testTileFuseAndDistributionOptions) {
+    RewritePatternSet patterns(&getContext());
+    fillTileFuseAndDistributePatterns(&getContext(), patterns);
+    (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+    return;
+  }
   if (testPatterns)
     return applyPatterns(getOperation());
   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)