diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -369,7 +369,10 @@ /// to /// /// %iv = %lb + %procId * %step - CyclicNumProcsEqNumIters = 2 + CyclicNumProcsEqNumIters = 2, + + /// No Distribution. + None = 3 }; /// Callback function type used to get processor ID, and number of processors @@ -377,11 +380,10 @@ struct ProcInfo { Value procId; Value nprocs; + DistributionMethod distributionMethod; }; -using ProcInfoCallBackFn = std::function( +using ProcInfoCallBackFn = std::function( OpBuilder &b, Location loc, ArrayRef parallelLoopRanges)>; -using OneDimProcInfoCallBackFn = - std::function; /// Options that allow distribution of loops generated in Linalg transforms to /// processors while generating the loops. @@ -389,21 +391,10 @@ /// Callback function that returns the Values for processor ID (`procId`), and /// number of processors (`nprocs`) used to execute the parallel loops. The /// number of `{procId, nprocs}` pairs returned must be equal to the number of - /// `parallelLoopRanges` passed into the callback, which in-turn is same as - /// the number of parallel loops for which the `distributionMethod` is - /// specified below. + /// `parallelLoopRanges` passed into the callback. The `parallelLoopRanges` + /// are ranges of the outer parallel loops of the operation that + /// do have non-zero tile sizes specified. ProcInfoCallBackFn procInfo; - /// Specification of how to distribute the `scf.parallel` loops that are - /// generated. As the `scf.parallel` loop is generated, the elements of this - /// vector is used (from left to right) and the specified distribution is - /// applied. If the vector is less than the number of `scf.parallel` loops - /// generated, then no distribution is applied. 
- SmallVector distributionMethod = {}; - - /// The map keyed by the distribution type that contains callback functions - /// that return the Values for processor ID (`procId`), and number of - /// processors (`nprocs`) used to execute the parallel loops. - DenseMap procInfoMap; }; /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`. @@ -521,8 +512,7 @@ function_ref bodyBuilderFn, - Optional = None, - ArrayRef distributionTypes = {}); + ArrayRef procInfo = {}); }; } // namespace linalg diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -450,6 +450,31 @@ applyPermutationToVector(iteratorTypes, permutation); } + // Handle distribution. Create a vector of the same size as the loops that + // are to be tiled. + SmallVector procInfo; + if (options.distribution) { + procInfo.resize( + iteratorTypes.size(), + linalg::ProcInfo{nullptr, nullptr, linalg::DistributionMethod::None}); + // Collect loop ranges of tiled loops, i.e. loops that are parallel. + SmallVector parallelLoopRanges; + for (auto iteratorType : llvm::enumerate(iteratorTypes)) { + if (!isParallelIterator(iteratorType.value())) + break; + parallelLoopRanges.push_back(loopRanges[iteratorType.index()]); + } + auto returnedProcInfo = + options.distribution->procInfo(b, op.getLoc(), parallelLoopRanges); + unsigned procIdIdx = 0; + // Update the distribution information for the loops. + for (auto iteratorType : llvm::enumerate(iteratorTypes)) { + if (!isParallelIterator(iteratorType.value())) + break; + procInfo[iteratorType.index()] = returnedProcInfo[procIdIdx++]; + } + } + // 2. Create the tiled loops. 
LinalgOp res = op; SmallVector ivs, tensorResults; @@ -489,8 +514,7 @@ return scf::ValueVector(tensorResults.begin(), tensorResults.end()); }; GenerateLoopNest::doit(b, op.getLoc(), loopRanges, op, iteratorTypes, - tiledLoopBodyBuilder, options.distribution, - options.distributionTypes); + tiledLoopBodyBuilder, procInfo); // 3. Transform IndexOp results w.r.t. the tiling. transformIndexOps(b, res, ivs, loopIndexToRangeIndex); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -518,25 +518,11 @@ function_ref bodyBuilderFn, - Optional distributionOptions, - ArrayRef distributionTypes) { + ArrayRef procInfo) { + assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) && + "expected as many entries for proc info as number of loops, even if " + "they are null entries"); SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); - // Create procInfo so it dominates loops, if appropriate. - SmallVector procInfo; - SmallVector distributionMethod; - if (distributionOptions) { - // Collect loop ranges for parallel dimensions. - SmallVector parallelLoopRanges; - for (const auto &iteratorType : enumerate(iteratorTypes)) - if (isParallelIterator(iteratorType.value())) - parallelLoopRanges.push_back(loopRanges[iteratorType.index()]); - - // Get their distribution schemes. 
- distributionMethod = distributionOptions->distributionMethod; - if (distributionMethod.size() < parallelLoopRanges.size()) - parallelLoopRanges.resize(distributionMethod.size()); - procInfo = distributionOptions->procInfo(b, loc, parallelLoopRanges); - } SmallVector lbs, ubs, steps; unpackRanges(b, loc, loopRanges, lbs, ubs, steps); @@ -554,20 +540,17 @@ return bodyBuilderFn(b, loc, ivs, operandValuesToUse); }); - if (!distributionOptions || loopNest.loops.empty()) + if (loopNest.loops.empty() || procInfo.empty()) return; // Filter out scf.for loops that were created out of parallel dimensions. - SmallVector loops; - for (const auto &iteratorType : enumerate(iteratorTypes)) - if (isParallelIterator(iteratorType.value())) - loops.push_back(loopNest.loops[iteratorType.index()]); - - // Distribute - only supports cyclic distribution for now. - for (auto it : llvm::zip(loops, procInfo, distributionMethod)) - if (std::get<2>(it) == DistributionMethod::Cyclic) - mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId, - std::get<1>(it).nprocs); + for (auto loop : llvm::enumerate(loopNest.loops)) { + if (procInfo[loop.index()].distributionMethod == + DistributionMethod::Cyclic) { + mapLoopToProcessorIds(loop.value(), procInfo[loop.index()].procId, + procInfo[loop.index()].nprocs); + } + } } /// Specialization to build affine "for" nest. 
@@ -578,7 +561,7 @@ function_ref bodyBuilderFn, - Optional, ArrayRef) { + ArrayRef /*procInfo*/) { SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); assert(iterArgInitValues.empty() && "unexpected AffineForOp init values"); SmallVector lbs, ubs, steps; @@ -625,12 +608,13 @@ static void generateParallelLoopNest( OpBuilder &b, Location loc, ValueRange lbs, ValueRange ubs, ValueRange steps, ArrayRef iteratorTypes, + ArrayRef procInfo, function_ref bodyBuilderFn, - SmallVectorImpl &ivStorage, - ArrayRef distributionMethod = {}) { + SmallVectorImpl &ivStorage) { assert(lbs.size() == ubs.size()); assert(lbs.size() == steps.size()); assert(lbs.size() == iteratorTypes.size()); + assert(procInfo.empty() || (lbs.size() == procInfo.size())); // If there are no (more) loops to be generated, generate the body and be // done with it. @@ -639,55 +623,56 @@ return; } - // Find the outermost parallel loops and drop their types from the list. - unsigned nLoops = iteratorTypes.size(); - unsigned nOuterPar = - nLoops - iteratorTypes.drop_while(isParallelIterator).size(); - // If there are no outer parallel loops, generate one sequential loop and - // recurse. Note that we wouldn't have dropped anything from `iteratorTypes` - // in this case. - if (nOuterPar == 0) { + // recurse. + if (!isParallelIterator(iteratorTypes.front())) { LoopNest singleLoop = buildLoopNest( b, loc, lbs.take_front(), ubs.take_front(), steps.take_front(), [&](OpBuilder &b, Location loc, ValueRange ivs) { ivStorage.append(ivs.begin(), ivs.end()); - generateParallelLoopNest(b, loc, lbs.drop_front(), ubs.drop_front(), - steps.drop_front(), - iteratorTypes.drop_front(), bodyBuilderFn, - ivStorage, distributionMethod); + generateParallelLoopNest( + b, loc, lbs.drop_front(), ubs.drop_front(), steps.drop_front(), + iteratorTypes.drop_front(), + procInfo.empty() ? 
procInfo : procInfo.drop_front(), + bodyBuilderFn, ivStorage); }); return; } - if (distributionMethod.empty()) { + + unsigned nLoops = iteratorTypes.size(); + unsigned numProcessed = 0; + DistributionMethod distributionMethod = DistributionMethod::None; + if (procInfo.empty()) { + numProcessed = nLoops - iteratorTypes.drop_while(isParallelIterator).size(); + } else { + distributionMethod = procInfo.front().distributionMethod; + numProcessed = + nLoops - procInfo + .drop_while([&](linalg::ProcInfo p) { + return p.distributionMethod == distributionMethod; + }) + .size(); + } + + auto remainderProcInfo = + procInfo.empty() ? procInfo : procInfo.drop_front(numProcessed); + switch (distributionMethod) { + case DistributionMethod::None: { // Generate a single parallel loop-nest operation for all outermost // parallel loops and recurse. b.create( - loc, lbs.take_front(nOuterPar), ubs.take_front(nOuterPar), - steps.take_front(nOuterPar), + loc, lbs.take_front(numProcessed), ubs.take_front(numProcessed), + steps.take_front(numProcessed), [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) { ivStorage.append(localIvs.begin(), localIvs.end()); generateParallelLoopNest( - nestedBuilder, nestedLoc, lbs.drop_front(nOuterPar), - ubs.drop_front(nOuterPar), steps.drop_front(nOuterPar), - iteratorTypes.drop_front(nOuterPar), bodyBuilderFn, ivStorage, - (distributionMethod.size() < nOuterPar) - ? ArrayRef() - : distributionMethod.drop_front(nOuterPar)); + nestedBuilder, nestedLoc, lbs.drop_front(numProcessed), + ubs.drop_front(numProcessed), steps.drop_front(numProcessed), + iteratorTypes.drop_front(numProcessed), remainderProcInfo, + bodyBuilderFn, ivStorage); }); return; } - - // Process all consecutive similarly distributed loops simultaneously. 
- DistributionMethod methodToUse = distributionMethod[0]; - unsigned numProcessed = 1; - for (unsigned i = 1; i < nOuterPar && i < distributionMethod.size(); ++i) { - if (distributionMethod[i] != methodToUse) - break; - numProcessed++; - } - - switch (methodToUse) { case DistributionMethod::Cyclic: { // Generate a single parallel loop-nest operation for all outermost // parallel loops and recurse. @@ -699,10 +684,8 @@ generateParallelLoopNest( nestedBuilder, nestedLoc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed), steps.drop_front(numProcessed), - iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage, - (distributionMethod.size() < numProcessed) - ? ArrayRef() - : distributionMethod.drop_front(numProcessed)); + iteratorTypes.drop_front(numProcessed), remainderProcInfo, + bodyBuilderFn, ivStorage); }); return; } @@ -714,11 +697,11 @@ cond = ab._and(cond, ab.slt(lbs[i], ubs[i])); ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed)); b.create(loc, cond, [&](OpBuilder &b, Location loc) { - generateParallelLoopNest( - b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed), - steps.drop_front(numProcessed), - iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage, - distributionMethod.drop_front(numProcessed)); + generateParallelLoopNest(b, loc, lbs.drop_front(numProcessed), + ubs.drop_front(numProcessed), + steps.drop_front(numProcessed), + iteratorTypes.drop_front(numProcessed), + remainderProcInfo, bodyBuilderFn, ivStorage); b.create(loc, ValueRange{}); }); return; @@ -730,7 +713,7 @@ generateParallelLoopNest( b, loc, lbs.drop_front(numProcessed), ubs.drop_front(numProcessed), steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed), - bodyBuilderFn, ivStorage, distributionMethod.drop_front(numProcessed)); + remainderProcInfo, bodyBuilderFn, ivStorage); return; } } @@ -743,13 +726,14 @@ function_ref bodyBuilderFn, - Optional distributionOptions, - ArrayRef distributionTypes) { + ArrayRef 
procInfo) { SmallVector iterArgInitValues = linalgOp.getOutputTensorOperands(); assert(iterArgInitValues.empty() && "unexpected ParallelOp init values"); // This function may be passed more iterator types than ranges. assert(iteratorTypes.size() >= loopRanges.size() && "expected iterator type for all ranges"); + assert((procInfo.empty() || (procInfo.size() == loopRanges.size())) && + "expected proc information for all loops when present"); iteratorTypes = iteratorTypes.take_front(loopRanges.size()); SmallVector lbsStorage, ubsStorage, stepsStorage, ivs; unsigned numLoops = iteratorTypes.size(); @@ -762,42 +746,22 @@ unpackRanges(b, loc, loopRanges, lbsStorage, ubsStorage, stepsStorage); // Modify the lb, ub, and step based on the distribution options. - SmallVector distributionMethod; - if (distributionOptions) { - auto &options = *distributionOptions; - distributionMethod.assign(distributionOptions->distributionMethod.begin(), - distributionOptions->distributionMethod.end()); - SmallVector parallelLoopRanges; - for (const auto &iteratorType : enumerate(iteratorTypes)) { - if (isParallelIterator(iteratorType.value())) - parallelLoopRanges.push_back(loopRanges[iteratorType.index()]); - } - if (distributionMethod.size() < parallelLoopRanges.size()) - parallelLoopRanges.resize(distributionMethod.size()); - SmallVector procInfo = - options.procInfo(b, loc, parallelLoopRanges); - unsigned index = 0; - for (const auto &iteratorType : enumerate(iteratorTypes)) { - if (index >= procInfo.size()) - break; - if (isParallelIterator(iteratorType.value())) { - unsigned i = iteratorType.index(); - updateBoundsForCyclicDistribution(b, loc, procInfo[index].procId, - procInfo[index].nprocs, lbsStorage[i], - ubsStorage[i], stepsStorage[i]); - index++; - } + for (auto it : llvm::enumerate(procInfo)) { + if (it.value().distributionMethod != linalg::DistributionMethod::None) { + updateBoundsForCyclicDistribution( + b, loc, it.value().procId, it.value().nprocs, lbsStorage[it.index()], + 
ubsStorage[it.index()], stepsStorage[it.index()]); } } ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage); generateParallelLoopNest( - b, loc, lbs, ubs, steps, iteratorTypes, + b, loc, lbs, ubs, steps, iteratorTypes, procInfo, [&](OpBuilder &b, Location loc, ValueRange ivs) { SmallVector operandValuesToUse = linalgOp.getInputAndOutputOperands(); bodyBuilderFn(b, loc, ivs, operandValuesToUse); }, - ivs, distributionMethod); + ivs); assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops"); } diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -249,14 +249,16 @@ template static SmallVector -getGpuProcIds(OpBuilder &b, Location loc, ArrayRef parallelLoopRanges) { +getGpuProcIds(OpBuilder &b, Location loc, ArrayRef parallelLoopRanges, + ArrayRef distributionMethod) { size_t count = std::min(3, parallelLoopRanges.size()); SmallVector procInfo(count); Type indexType = b.getIndexType(); for (unsigned i = 0; i < count; ++i) { gpu::Dimension dim = *gpu::symbolizeDimension(i); procInfo[count - 1 - i] = {b.create(loc, indexType, dim), - b.create(loc, indexType, dim)}; + b.create(loc, indexType, dim), + distributionMethod[count - 1 - i]}; } return procInfo; } @@ -265,10 +267,15 @@ RewritePatternSet &patterns) { { LinalgLoopDistributionOptions cyclicNprocsEqNiters; - cyclicNprocsEqNiters.distributionMethod.resize( - 2, DistributionMethod::CyclicNumProcsEqNumIters); + SmallVector distributionMethod = { + DistributionMethod::CyclicNumProcsEqNumIters, + DistributionMethod::CyclicNumProcsEqNumIters}; cyclicNprocsEqNiters.procInfo = - getGpuProcIds; + [distributionMethod](OpBuilder &b, Location loc, + ArrayRef parallelLoopRanges) { + return getGpuProcIds( + b, loc, parallelLoopRanges, distributionMethod); + }; patterns.add( MatmulOp::getOperationName(), 
context, LinalgTilingOptions() @@ -282,10 +289,15 @@ { LinalgLoopDistributionOptions cyclicNprocsGeNiters; - cyclicNprocsGeNiters.distributionMethod.resize( - 2, DistributionMethod::CyclicNumProcsGeNumIters); + SmallVector distributionMethod = { + DistributionMethod::CyclicNumProcsGeNumIters, + DistributionMethod::CyclicNumProcsGeNumIters}; cyclicNprocsGeNiters.procInfo = - getGpuProcIds; + [distributionMethod](OpBuilder &b, Location loc, + ArrayRef parallelLoopRanges) { + return getGpuProcIds( + b, loc, parallelLoopRanges, distributionMethod); + }; patterns.add( MatmulOp::getOperationName(), context, LinalgTilingOptions() @@ -299,10 +311,14 @@ { LinalgLoopDistributionOptions cyclicNprocsDefault; - cyclicNprocsDefault.distributionMethod.resize(2, - DistributionMethod::Cyclic); + SmallVector distributionMethod = { + DistributionMethod::Cyclic, DistributionMethod::Cyclic}; cyclicNprocsDefault.procInfo = - getGpuProcIds; + [distributionMethod](OpBuilder &b, Location loc, + ArrayRef parallelLoopRanges) { + return getGpuProcIds( + b, loc, parallelLoopRanges, distributionMethod); + }; patterns.add( MatmulOp::getOperationName(), context, LinalgTilingOptions() @@ -316,10 +332,15 @@ { LinalgLoopDistributionOptions cyclicNprocsMixed1; - cyclicNprocsMixed1.distributionMethod = { + SmallVector distributionMethod = { DistributionMethod::CyclicNumProcsEqNumIters, DistributionMethod::CyclicNumProcsGeNumIters}; - cyclicNprocsMixed1.procInfo = getGpuProcIds; + cyclicNprocsMixed1.procInfo = + [distributionMethod](OpBuilder &b, Location loc, + ArrayRef parallelLoopRanges) { + return getGpuProcIds( + b, loc, parallelLoopRanges, distributionMethod); + }; patterns.add( MatmulOp::getOperationName(), context, LinalgTilingOptions() @@ -333,10 +354,15 @@ { LinalgLoopDistributionOptions cyclicNprocsMixed2; - cyclicNprocsMixed2.distributionMethod = { + SmallVector distributionMethod = { DistributionMethod::CyclicNumProcsGeNumIters, DistributionMethod::Cyclic}; - cyclicNprocsMixed2.procInfo = 
getGpuProcIds; + cyclicNprocsMixed2.procInfo = + [distributionMethod](OpBuilder &b, Location loc, + ArrayRef parallelLoopRanges) { + return getGpuProcIds( + b, loc, parallelLoopRanges, distributionMethod); + }; patterns.add( MatmulOp::getOperationName(), context, LinalgTilingOptions() @@ -350,10 +376,15 @@ { LinalgLoopDistributionOptions cyclicNprocsMixed3; - cyclicNprocsMixed3.distributionMethod = { + SmallVector distributionMethod = { DistributionMethod::Cyclic, DistributionMethod::CyclicNumProcsEqNumIters}; - cyclicNprocsMixed3.procInfo = getGpuProcIds; + cyclicNprocsMixed3.procInfo = + [distributionMethod](OpBuilder &b, Location loc, + ArrayRef parallelLoopRanges) { + return getGpuProcIds( + b, loc, parallelLoopRanges, distributionMethod); + }; patterns.add( MatmulOp::getOperationName(), context, @@ -368,10 +399,14 @@ { LinalgLoopDistributionOptions cyclicNprocsEqNiters; - cyclicNprocsEqNiters.distributionMethod.resize(2, - DistributionMethod::Cyclic); + SmallVector distributionMethod = { + DistributionMethod::Cyclic, DistributionMethod::Cyclic}; cyclicNprocsEqNiters.procInfo = - getGpuProcIds; + [distributionMethod](OpBuilder &b, Location loc, + ArrayRef parallelLoopRanges) { + return getGpuProcIds( + b, loc, parallelLoopRanges, distributionMethod); + }; patterns.add( MatmulOp::getOperationName(), context, LinalgTilingOptions() @@ -387,8 +422,14 @@ static void fillTileFuseAndDistributePatterns(MLIRContext *context, RewritePatternSet &patterns) { LinalgLoopDistributionOptions cyclicNprocsEqNiters; - cyclicNprocsEqNiters.distributionMethod.resize(2, DistributionMethod::Cyclic); - cyclicNprocsEqNiters.procInfo = getGpuProcIds; + SmallVector distributionMethod = { + DistributionMethod::Cyclic, DistributionMethod::Cyclic}; + cyclicNprocsEqNiters.procInfo = + [distributionMethod](OpBuilder &b, Location loc, + ArrayRef parallelLoopRanges) { + return getGpuProcIds( + b, loc, parallelLoopRanges, distributionMethod); + }; patterns.add( MatmulOp::getOperationName(), 
context, LinalgTilingAndFusionOptions()