diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -615,6 +615,13 @@ return getBody()->getArguments().take_back(outputs().size()); } + /// Sets the `distribution_types` attribute; expects one entry per loop. + void setDistributionTypes(Builder& b, ArrayRef types) { + assert(types.size() == getNumLoops() && + "expected distribution type for every dimension"); + distribution_typesAttr(b.getStrArrayAttr(types)); + } + void setLowerBounds(ValueRange lowerBounds) { unsigned numLoops = getNumLoops(); assert(lowerBounds.size() == numLoops && diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h --- a/mlir/include/mlir/Dialect/Linalg/Passes.h +++ b/mlir/include/mlir/Dialect/Linalg/Passes.h @@ -30,7 +30,8 @@ createLinalgTilingToParallelLoopsPass(ArrayRef tileSizes = {}); std::unique_ptr> -createLinalgTilingToTiledLoopPass(ArrayRef tileSizes = {}); +createLinalgTilingToTiledLoopPass(ArrayRef tileSizes = {}, + ArrayRef distributionTypes = {}); std::unique_ptr> createLinalgPromotionPass(bool dynamicBuffers, bool useAlloca); diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td --- a/mlir/include/mlir/Dialect/Linalg/Passes.td +++ b/mlir/include/mlir/Dialect/Linalg/Passes.td @@ -182,6 +182,9 @@ let constructor = "mlir::createLinalgTilingToTiledLoopPass()"; let options = [ ListOption<"tileSizes", "linalg-tile-sizes", "int64_t", "Tile sizes", + "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">, + ListOption<"distributionTypes", "linalg-distribution-types", "std::string", + "Distribution types", "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated"> ]; let dependentDialects = [ diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++
b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -493,6 +493,14 @@ return *this; } + /// Specification markers of how to distribute the `linalg.tiled_loop`. + SmallVector distributionTypes = {}; + + LinalgTilingOptions &setDistributionTypes(ArrayRef types) { + distributionTypes.assign(types.begin(), types.end()); + return *this; + } + /// Computation function that returns a padding value to use when padding to /// force static sizes. When `paddingValueComputationFunction` is set, padding /// operations are introduced, that guarantee the underlying op is statically diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -247,7 +247,8 @@ function_ref bodyBuilderFn, - Optional = None); + Optional = None, + ArrayRef distributionTypes = {}); }; } // namespace linalg diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -278,7 +278,8 @@ return scf::ValueVector(tensorResults.begin(), tensorResults.end()); }; GenerateLoopNest::doit(b, op.getLoc(), loopRanges, op, iteratorTypes, - tiledLoopBodyBuilder, options.distribution); + tiledLoopBodyBuilder, options.distribution, + options.distributionTypes); // 3. Transform IndexOp results w.r.t. the tiling.
transformIndexOps(b, res, ivs, loopIndexToRangeIndex); @@ -428,11 +429,14 @@ >::insert(patterns, options); } -static void applyTilingToLoopPatterns(LinalgTilingLoopType loopType, - FuncOp funcOp, - ArrayRef tileSizes) { - auto options = - LinalgTilingOptions().setTileSizes(tileSizes).setLoopType(loopType); +static void +applyTilingToLoopPatterns(LinalgTilingLoopType loopType, FuncOp funcOp, + ArrayRef tileSizes, + ArrayRef distributionTypes = {}) { + auto options = LinalgTilingOptions() + .setTileSizes(tileSizes) + .setLoopType(loopType) + .setDistributionTypes(distributionTypes); MLIRContext *ctx = funcOp.getContext(); RewritePatternSet patterns(ctx); insertTilingPatterns(patterns, options); @@ -472,11 +476,19 @@ struct LinalgTilingToTiledLoopsPass : public LinalgTilingToTiledLoopsBase { LinalgTilingToTiledLoopsPass() = default; - LinalgTilingToTiledLoopsPass(ArrayRef sizes) { tileSizes = sizes; } + LinalgTilingToTiledLoopsPass(ArrayRef sizes, + ArrayRef types) { + tileSizes = sizes; + distributionTypes = llvm::to_vector<2>( + llvm::map_range(types, [](StringRef ref) { return ref.str(); })); + } void runOnFunction() override { - applyTilingToLoopPatterns(LinalgTilingLoopType::TiledLoops, getFunction(), - tileSizes); + applyTilingToLoopPatterns( + LinalgTilingLoopType::TiledLoops, getFunction(), tileSizes, + llvm::to_vector<2>(llvm::map_range( + distributionTypes, + [](const std::string &str) { return StringRef(str); }))); } }; @@ -493,6 +505,8 @@ } std::unique_ptr> -mlir::createLinalgTilingToTiledLoopPass(ArrayRef tileSizes) { - return std::make_unique(tileSizes); +mlir::createLinalgTilingToTiledLoopPass(ArrayRef tileSizes, + ArrayRef distributionTypes) { + return std::make_unique(tileSizes, + distributionTypes); } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -198,7 +198,8 @@ function_ref bodyBuilderFn, - Optional
distributionOptions) { + Optional distributionOptions, + ArrayRef distributionTypes) { auto iterArgInitValues = linalgOp.getOutputTensors(); // Create procInfo so it dominates loops, if appropriate. SmallVector procInfo; @@ -246,7 +247,7 @@ function_ref bodyBuilderFn, - Optional) { + Optional, ArrayRef) { auto iterArgInitValues = linalgOp.getOutputTensors(); assert(iterArgInitValues.empty() && "unexpected AffineForOp init values"); SmallVector lbs, ubs, steps; @@ -275,7 +276,8 @@ function_ref bodyBuilderFn, - Optional) { + Optional, + ArrayRef distributionTypes) { SmallVector procInfo; SmallVector lbs, ubs, steps; unpackRanges(loopRanges, lbs, ubs, steps); @@ -291,6 +293,9 @@ auto tiledLoop = b.create( loc, lbs, ubs, steps, linalgOp.getInputs(), linalgOp.getOutputs(), b.getArrayAttr(iteratorTypes), wrappedBuilderFn); + // Optional: when non-empty, setDistributionTypes asserts one entry per loop. + if (!distributionTypes.empty()) + tiledLoop.setDistributionTypes(b, distributionTypes); // Replace inputs/outputs with the corresponding region args. auto isInsideTiledLoop = [&](OpOperand &operand) { @@ -446,7 +451,8 @@ function_ref bodyBuilderFn, - Optional distributionOptions) { + Optional distributionOptions, + ArrayRef distributionTypes) { auto iterArgInitValues = linalgOp.getOutputTensors(); assert(iterArgInitValues.empty() && "unexpected ParallelOp init values"); // This function may be passed more iterator types than ranges.
diff --git a/mlir/test/Dialect/Linalg/tile-tensors.mlir b/mlir/test/Dialect/Linalg/tile-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-tensors.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s -linalg-tile="linalg-tile-sizes=2,3,4" -split-input-file | FileCheck %s -// RUN: mlir-opt %s -linalg-tile-to-tiled-loop="linalg-tile-sizes=2,3,4" -split-input-file | FileCheck %s -check-prefix=TLOOP +// RUN: mlir-opt %s -linalg-tile-to-tiled-loop="linalg-tile-sizes=2,3,4 linalg-distribution-types=block_x,block_y,none" -split-input-file | FileCheck %s -check-prefix=TLOOP // CHECK-LABEL: func @matmul_tensors( // CHECK-SAME: %[[TA:[0-9a-z]+]]: tensor @@ -48,7 +48,8 @@ // TLOOP-SAME: step (%[[C2]], %[[C3]], %[[C4]]) // TLOOP-SAME: ins (%[[A0:.*]] = %[[ARG_0]]: [[TY]], %[[A1:.*]] = %[[ARG_1]]: [[TY]]) // TLOOP-SAME: outs (%[[A2:.*]] = %[[ARG_2]]: [[TY]]) -// TLOOP-SAME: iterators["parallel", "parallel", "reduction"] { +// TLOOP-SAME: iterators["parallel", "parallel", "reduction"] +// TLOOP-SAME: distribution["block_x", "block_y", "none"] { // TLOOP: %[[SUB_ARG_0:.*]] = subtensor %[[A0]][%[[I]], %[[K]]] // TLOOP: %[[SUB_ARG_1:.*]] = subtensor %[[A1]][%[[K]], %[[J]]] @@ -128,6 +129,7 @@ // TLOOP-SAME: step (%[[C2]], %[[C3]], %[[C4]]) // TLOOP-SAME: ins (%{{.*}} = %[[ARG_0]]: [[TY]], %{{.*}} = %[[ARG_1]]: [[TY]]) // TLOOP-SAME: outs (%{{.*}} = %[[INIT]]: [[TY]]) +// TLOOP-SAME: distribution["block_x", "block_y", "none"] { // -----