diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -270,6 +270,15 @@
     loopType = lt;
     return *this;
   }
+
+  /// When specified, distributes the generated tile loops to processors.
+  Optional<LinalgLoopDistributionOptions> distribution = None;
+  LinalgTilingOptions &
+  setDistributionOptions(LinalgLoopDistributionOptions &distributionOptions) {
+    distribution = distributionOptions;
+    return *this;
+  }
 };
 
 /// Canonicalization patterns relevant to apply after tiling patterns. These are
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -156,6 +156,70 @@
     inVec = auxVec;
 }
 
+/// Scheme used to distribute loops to processors.
+enum class DistributionMethod {
+  /// Cyclic distribution where no assumption is made about the dynamic
+  /// relationship between the number of processors and the number of
+  /// iterations of the distributed loop. Distributes the following loop
+  ///
+  /// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
+  ///
+  /// to
+  ///
+  /// scf.parallel (%iv) = (%lb + %procId * %step) to (%ub) step (%step * %nprocs)
+  Cyclic = 0,
+
+  /// Cyclic distribution where the number of processors can be assumed to be
+  /// greater than or equal to the number of iterations of the distributed
+  /// loop. In such cases, a simple in-bounds check is enough (instead of
+  /// materializing a loop). Distributes the following loop
+  ///
+  /// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
+  ///
+  /// to
+  ///
+  /// %iv = %lb + %procId * %step
+  /// %cond = cmpi "slt", %iv, %ub
+  /// scf.if %cond {
+  ///   ...
+  /// }
+  CyclicNumProcsGeNumIters = 1,
+
+  /// Cyclic distribution where the number of processors can be assumed to be
+  /// equal to the number of iterations of the distributed loop. In such cases,
+  /// no bounds check is needed. Distributes the following loop
+  ///
+  /// scf.parallel (%iv) = (%lb) to (%ub) step (%step)
+  ///
+  /// to
+  ///
+  /// %iv = %lb + %procId * %step
+  CyclicNumProcsEqNumIters = 2
+};
+
+/// Callback function type used to get the processor ID and the number of
+/// processors used for distribution.
+struct ProcInfo {
+  Value procId;
+  Value nprocs;
+};
+using ProcInfoCallBackFn =
+    std::function<ProcInfo(OpBuilder &b, Location loc, unsigned loopNum)>;
+
+/// Options that allow distribution of loops generated in Linalg transforms to
+/// processors while generating the loops.
+struct LinalgLoopDistributionOptions {
+  /// Callback function that returns the Values for the processor ID and the
+  /// number of processors used to execute a given loop.
+  ProcInfoCallBackFn procInfo;
+  /// Specification of how to distribute the `scf.parallel` loops that are
+  /// generated. As each `scf.parallel` loop is generated, the elements of this
+  /// vector are used (from left to right) and the specified distribution is
+  /// applied. If the vector has fewer entries than the number of
+  /// `scf.parallel` loops generated, no distribution is applied to the
+  /// remaining loops.
+  SmallVector distributionMethod = {};
+};
+
 /// Utility class used to generate nested loops with ranges described by
 /// `loopRanges` and loop type described by the `iteratorTypes`. `bodyBuilderFn`
 /// is used to generate the body of the innermost loop. It is passed a range
@@ -168,7 +232,8 @@
   static void doit(ArrayRef loopRanges, ArrayRef iteratorTypes,
-                   function_ref<void(ValueRange)> bodyBuilderFn);
+                   function_ref<void(ValueRange)> bodyBuilderFn,
+                   Optional<LinalgLoopDistributionOptions> = None);
 };
 
 } // namespace linalg
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -382,7 +382,8 @@
   if (!options.interchangeVector.empty())
     applyPermutationToVector(iteratorTypes, options.interchangeVector);
   GenerateLoopNest::doit(
-      loopRanges, iteratorTypes, [&](ValueRange localIvs) {
+      loopRanges, iteratorTypes,
+      [&](ValueRange localIvs) {
         auto &b = ScopedContext::getBuilderRef();
         auto loc = ScopedContext::getLocation();
         ivs.assign(localIvs.begin(), localIvs.end());
@@ -401,7 +402,8 @@
         auto operands = getAssumedNonViewOperands(op);
         views.append(operands.begin(), operands.end());
         res = op.clone(b, loc, views);
-      });
+      },
+      options.distribution);
 
   // 4. Transforms index arguments of `linalg.generic` w.r.t. to the tiling.
   transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex);
@@ -410,8 +412,14 @@
   SmallVector loops;
   loops.reserve(ivs.size());
   for (auto iv : ivs) {
-    loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
-    assert(loops.back() && "no owner found for induction variable!");
+    if (iv.isa<BlockArgument>()) {
+      loops.push_back(iv.cast<BlockArgument>().getOwner()->getParentOp());
+      assert(loops.back() && "no owner found for induction variable!");
+    } else {
+      // TODO: Instead of doing this, try to recover the ops used instead of
+      // the loop.
+      loops.push_back(nullptr);
+    }
   }
   return TiledLinalgOp{res, loops};
 }
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
@@ -149,7 +150,8 @@
 template <>
 void GenerateLoopNest::doit(
     ArrayRef loopRanges, ArrayRef iteratorTypes,
-    function_ref<void(ValueRange)> bodyBuilderFn) {
+    function_ref<void(ValueRange)> bodyBuilderFn,
+    Optional<LinalgLoopDistributionOptions>) {
   SmallVector lbs, ubs, steps;
   unpackRanges(loopRanges, lbs, ubs, steps);
   edsc::loopNestBuilder(lbs, ubs, steps, bodyBuilderFn);
@@ -159,7 +161,8 @@
 template <>
 void GenerateLoopNest::doit(
     ArrayRef loopRanges, ArrayRef iteratorTypes,
-    function_ref<void(ValueRange)> bodyBuilderFn) {
+    function_ref<void(ValueRange)> bodyBuilderFn,
+    Optional<LinalgLoopDistributionOptions>) {
   SmallVector lbs, ubs, steps;
   unpackRanges(loopRanges, lbs, ubs, steps);
@@ -175,12 +178,24 @@
   edsc::affineLoopNestBuilder(lbs, ubs, constantSteps, bodyBuilderFn);
 }
 
-/// Generates a loop nest consisting of scf.parallel and scf.for, depending on
-/// the `iteratorTypes.` Consecutive parallel loops create a single scf.parallel
-/// operation; each sequential loop creates a new scf.for operation. The body
-/// of the innermost loop is populated by `bodyBuilderFn` that accepts a range
-/// of induction variables for all loops. `ivStorage` is used to store the
-/// partial list of induction variables.
+/// Update the `lb`, `ub` and `step` to get per-processor `lb`, `ub` and
+/// `step`.
+static void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc,
+                                              Value procId, Value nprocs,
+                                              Value &lb, Value &ub,
+                                              Value &step) {
+  using edsc::op::operator+;
+  using edsc::op::operator*;
+  lb = lb + (procId * step);
+  step = nprocs * step;
+}
+
+/// Generates a loop nest consisting of scf.parallel and scf.for, depending
+/// on the `iteratorTypes`. Consecutive parallel loops create a single
+/// scf.parallel operation; each sequential loop creates a new scf.for
+/// operation. The body of the innermost loop is populated by
+/// `bodyBuilderFn` that accepts a range of induction variables for all
+/// loops. `ivStorage` is used to store the partial list of induction
+/// variables.
 // TODO: this function can be made iterative instead. However, it
 // will have at most as many recursive calls as nested loops, which rarely
 // exceeds 10.
@@ -188,7 +203,8 @@
 generateParallelLoopNest(ValueRange lbs, ValueRange ubs, ValueRange steps,
                          ArrayRef iteratorTypes,
                          function_ref<void(ValueRange)> bodyBuilderFn,
-                         SmallVectorImpl<Value> &ivStorage) {
+                         SmallVectorImpl<Value> &ivStorage,
+                         ArrayRef<DistributionMethod> distributionMethod = {}) {
   assert(lbs.size() == ubs.size());
   assert(lbs.size() == steps.size());
   assert(lbs.size() == iteratorTypes.size());
@@ -200,8 +216,8 @@
   // Find the outermost parallel loops and drop their types from the list.
   unsigned nLoops = iteratorTypes.size();
-  iteratorTypes = iteratorTypes.drop_while(isParallelIteratorType);
-  unsigned nOuterPar = nLoops - iteratorTypes.size();
+  unsigned nOuterPar =
+      nLoops - iteratorTypes.drop_while(isParallelIteratorType).size();
 
   // If there are no outer parallel loops, generate one sequential loop and
   // recurse. Note that we wouldn't have dropped anything from `iteratorTypes`
@@ -211,41 +227,132 @@
       ivStorage.push_back(iv);
       generateParallelLoopNest(lbs.drop_front(), ubs.drop_front(),
                                steps.drop_front(), iteratorTypes.drop_front(),
-                               bodyBuilderFn, ivStorage);
+                               bodyBuilderFn, ivStorage, distributionMethod);
     });
     return;
   }
+
+  if (distributionMethod.empty()) {
+    // Generate a single parallel loop-nest operation for all outermost
+    // parallel loops and recurse.
+    edsc::OperationBuilder(
+        lbs.take_front(nOuterPar), ubs.take_front(nOuterPar),
+        steps.take_front(nOuterPar),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
+          edsc::ScopedContext context(nestedBuilder, nestedLoc);
+          ivStorage.append(localIvs.begin(), localIvs.end());
+          generateParallelLoopNest(
+              lbs.drop_front(nOuterPar), ubs.drop_front(nOuterPar),
+              steps.drop_front(nOuterPar), iteratorTypes.drop_front(nOuterPar),
+              bodyBuilderFn, ivStorage,
+              (distributionMethod.size() < nOuterPar)
+                  ? ArrayRef<DistributionMethod>()
+                  : distributionMethod.drop_front(nOuterPar));
+        });
+    return;
+  }
+
+  // Process all consecutive similarly distributed loops simultaneously.
+  DistributionMethod methodToUse = distributionMethod[0];
+  unsigned numProcessed = 1;
+  for (unsigned i = 1; i < nOuterPar && i < distributionMethod.size(); ++i) {
+    if (distributionMethod[i] != methodToUse)
+      break;
+    numProcessed++;
+  }
 
-  // Generate a single parallel loop-nest operation for all outermost parallel
-  // loops and recurse.
-  edsc::OperationBuilder(
-      lbs.take_front(nOuterPar), ubs.take_front(nOuterPar),
-      steps.take_front(nOuterPar),
-      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
-        edsc::ScopedContext context(nestedBuilder, nestedLoc);
-        ivStorage.append(localIvs.begin(), localIvs.end());
-        generateParallelLoopNest(lbs.drop_front(nOuterPar),
-                                 ubs.drop_front(nOuterPar),
-                                 steps.drop_front(nOuterPar), iteratorTypes,
-                                 bodyBuilderFn, ivStorage);
-      });
+  switch (methodToUse) {
+  case DistributionMethod::Cyclic: {
+    // Generate a single parallel loop-nest operation for all outermost
+    // parallel loops and recurse.
+    edsc::OperationBuilder(
+        lbs.take_front(numProcessed), ubs.take_front(numProcessed),
+        steps.take_front(numProcessed),
+        [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
+          edsc::ScopedContext context(nestedBuilder, nestedLoc);
+          ivStorage.append(localIvs.begin(), localIvs.end());
+          generateParallelLoopNest(
+              lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
+              steps.drop_front(numProcessed),
+              iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
+              (distributionMethod.size() < numProcessed)
+                  ? ArrayRef<DistributionMethod>()
+                  : distributionMethod.drop_front(numProcessed));
+        });
+    return;
+  }
+  case DistributionMethod::CyclicNumProcsGeNumIters: {
+    // Check (for the processed loops) that the iteration is in-bounds.
+    using edsc::op::slt;
+    using edsc::op::operator&&;
+    Value cond = slt(lbs[0], ubs[0]);
+    for (unsigned i = 1; i < numProcessed; ++i)
+      cond = cond && slt(lbs[i], ubs[i]);
+    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
+    edsc::conditionBuilder(cond, [&]() {
+      generateParallelLoopNest(
+          lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
+          steps.drop_front(numProcessed),
+          iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
+          distributionMethod.drop_front(numProcessed));
+    });
+    return;
+  }
+  case DistributionMethod::CyclicNumProcsEqNumIters:
+    // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
+    // with inner loop generation.
+    ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
+    generateParallelLoopNest(
+        lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
+        steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
+        bodyBuilderFn, ivStorage, distributionMethod.drop_front(numProcessed));
+    return;
+  }
 }
 
 /// Specialization for generating a mix of parallel and sequential scf loops.
 template <>
 void GenerateLoopNest::doit(
     ArrayRef loopRanges, ArrayRef iteratorTypes,
-    function_ref<void(ValueRange)> bodyBuilderFn) {
-  SmallVector lbsStorage, ubsStorage, stepsStorage, ivs;
-  unpackRanges(loopRanges, lbsStorage, ubsStorage, stepsStorage);
-  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
-
+    function_ref<void(ValueRange)> bodyBuilderFn,
+    Optional<LinalgLoopDistributionOptions> distributionOptions) {
   // This function may be passed more iterator types than ranges.
   assert(iteratorTypes.size() >= loopRanges.size() &&
          "expected iterator type for all ranges");
   iteratorTypes = iteratorTypes.take_front(loopRanges.size());
-  ivs.reserve(iteratorTypes.size());
-  generateParallelLoopNest(lbs, ubs, steps, iteratorTypes, bodyBuilderFn, ivs);
+  SmallVector lbsStorage, ubsStorage, stepsStorage, ivs;
+  unsigned numLoops = iteratorTypes.size();
+  ivs.reserve(numLoops);
+  lbsStorage.reserve(numLoops);
+  ubsStorage.reserve(numLoops);
+  stepsStorage.reserve(numLoops);
+
+  // Get the loop lb, ub, and step.
+  unpackRanges(loopRanges, lbsStorage, ubsStorage, stepsStorage);
+
+  // Modify the lb, ub, and step based on the distribution options.
+  SmallVector distributionMethod;
+  if (distributionOptions) {
+    auto &options = distributionOptions.getValue();
+    unsigned index = 0;
+    OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
+    Location loc = edsc::ScopedContext::getLocation();
+    distributionMethod.assign(distributionOptions->distributionMethod.begin(),
+                              distributionOptions->distributionMethod.end());
+    for (auto iteratorType : enumerate(iteratorTypes))
+      if (isParallelIteratorType(iteratorType.value()) &&
+          index < distributionMethod.size()) {
+        unsigned i = iteratorType.index();
+        ProcInfo procInfo = options.procInfo(builder, loc, index);
+        updateBoundsForCyclicDistribution(builder, loc, procInfo.procId,
+                                          procInfo.nprocs, lbsStorage[i],
+                                          ubsStorage[i], stepsStorage[i]);
+        index++;
+      }
+  }
+  ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
+  generateParallelLoopNest(lbs, ubs, steps, iteratorTypes, bodyBuilderFn, ivs,
+                           distributionMethod);
+
   assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
 }
diff --git a/mlir/test/Dialect/Linalg/tile-and-distribute.mlir b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/tile-and-distribute.mlir
@@ -0,0 +1,168 @@
+// RUN: mlir-opt %s -test-linalg-transform-patterns=test-tile-and-distribute-options -split-input-file | FileCheck %s
+
+func @gemm1(%a : memref, %b : memref, %c : memref)
+{
+  linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute1"}
+    : (memref, memref, memref)
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK: func @gemm1(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref
+// CHECK: %[[T1:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[T2:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: scf.for %[[ARG3:.*]] =
+// CHECK: %[[T3:.*]] = affine.apply #[[MAP0]]()[%[[T1]]]
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[T3]], %[[ARG3]]]
+// CHECK: %[[T11:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[T11]]]
+// CHECK: %[[T15:.*]] = affine.apply #[[MAP0]]()[%[[T1]]]
+// CHECK: %[[T18:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[T15]], %[[T18]]]
+// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+
+// -----
+
+func @gemm2(%a : memref, %b : memref, %c : memref)
+{
+  linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute2"}
+    : (memref, memref, memref)
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK: func @gemm2(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref
+// CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
+// CHECK: %[[T5:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[T6:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
+// CHECK: %[[T7:.*]] = cmpi "slt", %[[T4]], %{{.*}}
+// CHECK: %[[T8:.*]] = cmpi "slt", %[[T6]], %{{.*}}
+// CHECK: %[[T9:.*]] = and %[[T7]], %[[T8]]
+// CHECK: scf.if %[[T9]]
+// CHECK: scf.for %[[ARG3:.*]] =
+// CHECK: %[[T10:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[T10]], %[[ARG3]]]
+// CHECK: %[[T18:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[T18]]]
+// CHECK: %[[T22:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
+// CHECK: %[[T25:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[T22]], %[[T25]]]
+// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+
+// -----
+
+func @gemm3(%a : memref, %b : memref, %c : memref)
+{
+  linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute3"}
+    : (memref, memref, memref)
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK: func @gemm3(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref
+// CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[T4:.*]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK: %[[T5:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
+// CHECK: %[[T6:.*]] = affine.apply #[[MAP0]]()[%[[T4]]]
+// CHECK: %[[T7:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[T8:.*]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK: %[[T9:.*]] = affine.apply #[[MAP0]]()[%[[T7]]]
+// CHECK: %[[T10:.*]] = affine.apply #[[MAP0]]()[%[[T8]]]
+// CHECK: scf.parallel (%[[ARG3:.*]], %[[ARG4:.*]]) = (%[[T5]], %[[T9]]) to (%{{.*}}, %{{.*}}) step (%[[T6]], %[[T10]])
+// CHECK: scf.for %[[ARG5:.*]] =
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[ARG3]], %[[ARG5]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG5]], %[[ARG4]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[ARG4]]]
+// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+
+// -----
+
+func @gemm4(%a : memref, %b : memref, %c : memref)
+{
+  linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute4"}
+    : (memref, memref, memref)
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK: func @gemm4(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref
+// CHECK: %[[T2:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
+// CHECK: %[[T5:.*]] = cmpi "slt", %[[T4]], %{{.*}}
+// CHECK: scf.if %[[T5]]
+// CHECK: scf.for %[[ARG3:.*]] =
+// CHECK: %[[T6:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[T6]], %[[ARG3]]]
+// CHECK: %[[T14:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG3]], %[[T14]]]
+// CHECK: %[[T18:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
+// CHECK: %[[T21:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[T18]], %[[T21]]]
+// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+
+// -----
+
+func @gemm5(%a : memref, %b : memref, %c : memref)
+{
+  linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute5"}
+    : (memref, memref, memref)
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK: func @gemm5(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref
+// CHECK: %[[T3:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
+// CHECK: %[[T5:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: %[[T6:.*]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK: %[[T7:.*]] = affine.apply #[[MAP0]]()[%[[T5]]]
+// CHECK: %[[T8:.*]] = affine.apply #[[MAP0]]()[%[[T6]]]
+// CHECK: %[[T9:.*]] = cmpi "slt", %[[T4]], %{{.*}}
+// CHECK: scf.if %[[T9]]
+// CHECK: scf.parallel (%[[ARG3:.*]]) = (%[[T7]]) to (%{{.*}}) step (%[[T8]])
+// CHECK: scf.for %[[ARG4:.*]] =
+// CHECK: %[[T10:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[T10]], %[[ARG4]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[ARG3]]]
+// CHECK: %[[T21:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[T21]], %[[ARG3]]]
+// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
+
+// -----
+
+func @gemm6(%a : memref, %b : memref, %c : memref)
+{
+  linalg.matmul %a, %b, %c {__internal_linalg_transform__ = "distribute6"}
+    : (memref, memref, memref)
+  return
+}
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 8)>
+// CHECK: func @gemm6(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]: memref
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]: memref
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]*]]: memref
+// CHECK: %[[T2:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK: %[[T3:.*]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK: %[[T4:.*]] = affine.apply #[[MAP0]]()[%[[T2]]]
+// CHECK: %[[T5:.*]] = affine.apply #[[MAP0]]()[%[[T3]]]
+// CHECK: %[[T6:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK: scf.parallel (%[[ARG3:.*]]) = (%[[T4]]) to (%{{.*}}) step (%[[T5]])
+// CHECK: scf.for %[[ARG4:.*]] =
+// CHECK: %[[SV1:.*]] = subview %[[ARG0]][%[[ARG3]], %[[ARG4]]]
+// CHECK: %[[T14:.*]] = affine.apply #[[MAP0]]()[%[[T6]]]
+// CHECK: %[[SV2:.*]] = subview %[[ARG1]][%[[ARG4]], %[[T14]]]
+// CHECK: %[[T20:.*]] = affine.apply #[[MAP0]]()[%[[T6]]]
+// CHECK: %[[SV3:.*]] = subview %[[ARG2]][%[[ARG3]], %[[T20]]]
+// CHECK: linalg.matmul %[[SV1]], %[[SV2]], %[[SV3]]
diff --git a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Transforms/TestLinalgTransforms.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
@@ -49,6 +50,10 @@
   Option testPromotionOptions{*this, "test-linalg-promotion-options",
                               llvm::cl::desc("Test promotion options"),
                               llvm::cl::init(false)};
+  Option testTileAndDistributionOptions{
+      *this, "test-tile-and-distribute-options",
+      llvm::cl::desc("Test tile and distribute options"),
+      llvm::cl::init(false)};
   Option testVectorTransferForwardingPatterns{
       *this, "test-vector-transfer-forwarding-patterns",
       llvm::cl::desc(
@@ -143,6 +148,11 @@
       /*loweringType=*/LinalgLoweringType::Loops,
       LinalgMarker(Identifier::get("REG", ctx)));
 
+  //===--------------------------------------------------------------------===//
+  // Linalg distribution patterns.
+  //===--------------------------------------------------------------------===//
+  LinalgLoopDistributionOptions distributionOptions;
+
   //===--------------------------------------------------------------------===//
   // Linalg to vector contraction patterns.
   //===--------------------------------------------------------------------===//
@@ -278,6 +288,122 @@
       LinalgMarker(Identifier::get("PROMOTE", ctx)));
 }
 
+template
+static ProcInfo getGpuProcIds(OpBuilder &b, Location loc, unsigned loopNum) {
+  Type indexType = b.getIndexType();
+  switch (loopNum) {
+  case 0:
+    return {b.create(loc, indexType, b.getStringAttr("y")),
+            b.create(loc, indexType, b.getStringAttr("y"))};
+  case 1:
+    return {b.create(loc, indexType, b.getStringAttr("x")),
+            b.create(loc, indexType, b.getStringAttr("x"))};
+  default:
+    llvm_unreachable("test pattern handles only up to 2-level nested loops");
+  }
+  return {nullptr, nullptr};
+}
+
+static void fillTileAndDistributePatterns(MLIRContext *context,
+                                          OwningRewritePatternList &patterns) {
+  {
+    LinalgLoopDistributionOptions cyclicNprocsEqNiters;
+    cyclicNprocsEqNiters.distributionMethod.resize(
+        2, DistributionMethod::CyclicNumProcsEqNumIters);
+    cyclicNprocsEqNiters.procInfo =
+        getGpuProcIds;
+    patterns.insert>(
+        context,
+        LinalgTilingOptions()
+            .setTileSizes({8, 8, 4})
+            .setLoopType(LinalgTilingLoopType::ParallelLoops)
+            .setDistributionOptions(cyclicNprocsEqNiters),
+        LinalgMarker(Identifier::get("distribute1", context),
+                     Identifier::get("after_distribute1", context)));
+  }
+
+  {
+    LinalgLoopDistributionOptions cyclicNprocsGeNiters;
+    cyclicNprocsGeNiters.distributionMethod.resize(
+        2, DistributionMethod::CyclicNumProcsGeNumIters);
+    cyclicNprocsGeNiters.procInfo =
+        getGpuProcIds;
+    patterns.insert>(
+        context,
+        LinalgTilingOptions()
+            .setTileSizes({8, 8, 4})
+            .setLoopType(LinalgTilingLoopType::ParallelLoops)
+            .setDistributionOptions(cyclicNprocsGeNiters),
+        LinalgMarker(Identifier::get("distribute2", context),
+                     Identifier::get("after_distribute2", context)));
+  }
+
+  {
+    LinalgLoopDistributionOptions cyclicNprocsDefault;
+    cyclicNprocsDefault.distributionMethod.resize(2,
+                                                  DistributionMethod::Cyclic);
+    cyclicNprocsDefault.procInfo =
+        getGpuProcIds;
+    patterns.insert>(
+        context,
+        LinalgTilingOptions()
+            .setTileSizes({8, 8, 4})
+            .setLoopType(LinalgTilingLoopType::ParallelLoops)
+            .setDistributionOptions(cyclicNprocsDefault),
+        LinalgMarker(Identifier::get("distribute3", context),
+                     Identifier::get("after_distribute3", context)));
+  }
+
+  {
+    LinalgLoopDistributionOptions cyclicNprocsMixed1;
+    cyclicNprocsMixed1.distributionMethod = {
+        DistributionMethod::CyclicNumProcsEqNumIters,
+        DistributionMethod::CyclicNumProcsGeNumIters};
+    cyclicNprocsMixed1.procInfo = getGpuProcIds;
+    patterns.insert>(
+        context,
+        LinalgTilingOptions()
+            .setTileSizes({8, 8, 4})
+            .setLoopType(LinalgTilingLoopType::ParallelLoops)
+            .setDistributionOptions(cyclicNprocsMixed1),
+        LinalgMarker(Identifier::get("distribute4", context),
+                     Identifier::get("after_distribute4", context)));
+  }
+
+  {
+    LinalgLoopDistributionOptions cyclicNprocsMixed2;
+    cyclicNprocsMixed2.distributionMethod = {
+        DistributionMethod::CyclicNumProcsGeNumIters,
+        DistributionMethod::Cyclic};
+    cyclicNprocsMixed2.procInfo = getGpuProcIds;
+    patterns.insert>(
+        context,
+        LinalgTilingOptions()
+            .setTileSizes({8, 8, 4})
+            .setLoopType(LinalgTilingLoopType::ParallelLoops)
+            .setDistributionOptions(cyclicNprocsMixed2),
+        LinalgMarker(Identifier::get("distribute5", context),
+                     Identifier::get("after_distribute5", context)));
+  }
+
+  {
+    LinalgLoopDistributionOptions cyclicNprocsMixed3;
+    cyclicNprocsMixed3.distributionMethod = {
+        DistributionMethod::Cyclic,
+        DistributionMethod::CyclicNumProcsEqNumIters};
+    cyclicNprocsMixed3.procInfo = getGpuProcIds;
+
+    patterns.insert>(
+        context,
+        LinalgTilingOptions()
+            .setTileSizes({8, 8, 4})
+            .setLoopType(LinalgTilingLoopType::ParallelLoops)
+            .setDistributionOptions(cyclicNprocsMixed3),
+        LinalgMarker(Identifier::get("distribute6", context),
+                     Identifier::get("after_distribute6", context)));
+  }
+}
+
 static void applyMatmulToVectorPatterns(FuncOp funcOp,
                                         bool testMatmulToVectorPatterns1dTiling,
@@ -344,6 +470,12 @@
     applyPatternsAndFoldGreedily(getFunction(), patterns);
     return;
   }
+  if (testTileAndDistributionOptions) {
+    OwningRewritePatternList patterns;
+    fillTileAndDistributePatterns(&getContext(), patterns);
+    applyPatternsAndFoldGreedily(getFunction(), patterns);
+    return;
+  }
   if (testPatterns)
     return applyPatterns(getFunction());
   if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
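For orientation, here is a minimal sketch of how a downstream pass might wire up the options introduced by this patch, mirroring the test pattern above. The helper name `getGpuBlockProcInfo` and the function `addTileAndDistributePattern` are hypothetical, and the sketch assumes the `LinalgTilingPattern<MatmulOp>` pattern class and the `gpu.block_id`/`gpu.grid_dim` builder signatures (dimension passed as a StringAttr) that the test file relies on; it is not part of the patch itself.

```cpp
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"

using namespace mlir;
using namespace mlir::linalg;

// Hypothetical helper: map tile loop 0 to the "y" GPU block dimension and
// tile loop 1 to the "x" dimension, returning the (processor id, number of
// processors) pair expected by ProcInfoCallBackFn.
static ProcInfo getGpuBlockProcInfo(OpBuilder &b, Location loc,
                                    unsigned loopNum) {
  Type indexType = b.getIndexType();
  StringRef dim = loopNum == 0 ? "y" : "x";
  return {b.create<gpu::BlockIdOp>(loc, indexType, b.getStringAttr(dim)),
          b.create<gpu::GridDimOp>(loc, indexType, b.getStringAttr(dim))};
}

// Hypothetical registration: tile a linalg.matmul by {8, 8, 4} and distribute
// the two outer parallel tile loops cyclically across GPU blocks; the
// reduction loop remains a sequential scf.for.
static void addTileAndDistributePattern(MLIRContext *ctx,
                                        OwningRewritePatternList &patterns) {
  LinalgLoopDistributionOptions distributionOptions;
  distributionOptions.procInfo = getGpuBlockProcInfo;
  distributionOptions.distributionMethod = {DistributionMethod::Cyclic,
                                            DistributionMethod::Cyclic};
  patterns.insert<LinalgTilingPattern<MatmulOp>>(
      ctx, LinalgTilingOptions()
               .setTileSizes({8, 8, 4})
               .setLoopType(LinalgTilingLoopType::ParallelLoops)
               .setDistributionOptions(distributionOptions));
}
```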