diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -860,6 +860,16 @@
     RewritePatternSet &patterns,
     LinalgTransformationFilter filter = LinalgTransformationFilter());
 
+/// Linalg distribution patterns
+///
+/// Populates `patterns` with a pattern that distributes `linalg.tiled_loop`
+/// ops accepted by `marker`. Every loop whose distribution type has a callback
+/// in `opts.procInfoMap` must be parallel; its bounds and step are updated for
+/// cyclic distribution over the processors returned by that callback.
+void populateLinalgDistributeTiledLoopPattern(
+    RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
+    const LinalgTransformationFilter &marker);
+
 //===----------------------------------------------------------------------===//
 // Op-specific patterns.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -184,6 +184,10 @@
 };
 using ProcInfoCallBackFn = std::function<SmallVector<ProcInfo, 2>(
     OpBuilder &b, Location loc, ArrayRef<Range> parallelLoopRanges)>;
+/// Callback returning the processor ID (`procId`) and number of processors
+/// (`nprocs`) used to distribute a single loop dimension.
+using OneDimProcInfoCallBackFn =
+    std::function<ProcInfo(OpBuilder &b, Location loc)>;
 
 /// Options that allow distribution of loops generated in Linalg transforms to
 /// processors while generating the loops.
@@ -201,6 +205,11 @@
   /// applied. If the vector is less than the number of `scf.parallel` loops
   /// generated, then no distribution is applied.
   SmallVector<DistributionMethod, 0> distributionMethod = {};
+
+  /// The map keyed by the distribution type that contains callback functions
+  /// that return the Values for processor ID (`procId`), and number of
+  /// processors (`nprocs`) used to execute the parallel loops.
+  DenseMap<StringRef, OneDimProcInfoCallBackFn> procInfoMap;
 };
 
 /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   CodegenStrategy.cpp
   ComprehensiveBufferize.cpp
   Detensorize.cpp
+  Distribution.cpp
   DropUnitDims.cpp
   ElementwiseToLinalg.cpp
   Fusion.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp b/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/Distribution.cpp
@@ -0,0 +1,105 @@
+//===- Distribution.cpp - Linalg tiled_loop distribution ------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Linalg distribution pattern. It updates the
+// `tiled_loop` control variables depending on the distribution type of each
+// loop.
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#define DEBUG_TYPE "linalg-distribution"
+
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+namespace {
+
+/// Distributes the loops of a `linalg.tiled_loop` to processors. For every
+/// loop whose distribution type has a callback in `procInfoMap`, the callback
+/// provides the processor ID and the number of processors, and the loop's
+/// bounds and step are rewritten with `updateBoundsForCyclicDistribution`.
+/// Loops whose distribution type has no callback are left untouched.
+struct DistributeTiledLoopPattern
+    : public OpRewritePattern<linalg::TiledLoopOp> {
+  DistributeTiledLoopPattern(MLIRContext *context,
+                             LinalgLoopDistributionOptions options,
+                             LinalgTransformationFilter marker)
+      : OpRewritePattern<linalg::TiledLoopOp>(context),
+        options(std::move(options)), marker(std::move(marker)) {}
+
+  LogicalResult matchAndRewrite(linalg::TiledLoopOp op,
+                                PatternRewriter &rewriter) const override {
+    if (failed(marker.checkAndNotify(rewriter, op)))
+      return failure();
+    if (!op.distribution_types().hasValue())
+      return failure();
+
+    ArrayAttr distributionTypes = op.distribution_types().getValue();
+    unsigned numLoops = op.getNumLoops();
+    if (distributionTypes.size() != numLoops)
+      return rewriter.notifyMatchFailure(
+          op, "expected one distribution type per loop");
+
+    // Collect and validate the loops to distribute *before* creating any IR:
+    // a pattern that returns failure must leave the IR untouched.
+    SmallVector<std::pair<unsigned, const OneDimProcInfoCallBackFn *>, 2>
+        loopsToDistribute;
+    for (unsigned i = 0; i < numLoops; ++i) {
+      StringRef type = distributionTypes[i].cast<StringAttr>().getValue();
+      auto procInfoCallback = options.procInfoMap.find(type);
+      if (procInfoCallback == options.procInfoMap.end())
+        continue;
+
+      if (!isParallelIteratorType(op.iterator_types()[i])) {
+        op.emitOpError("only support for parallel loops is implemented");
+        return failure();
+      }
+      loopsToDistribute.emplace_back(i, &procInfoCallback->second);
+    }
+
+    // Update bounds and steps.
+    Location loc = op.getLoc();
+    SmallVector<Value, 2> newLowerBounds = op.lowerBound();
+    SmallVector<Value, 2> newUpperBounds = op.upperBound();
+    SmallVector<Value, 2> newSteps = op.step();
+    for (const auto &loop : loopsToDistribute) {
+      unsigned i = loop.first;
+      ProcInfo info = (*loop.second)(rewriter, loc);
+      updateBoundsForCyclicDistribution(rewriter, loc, info.procId, info.nprocs,
+                                        newLowerBounds[i], newUpperBounds[i],
+                                        newSteps[i]);
+    }
+    rewriter.updateRootInPlace(op, [&] {
+      op.setLowerBounds(newLowerBounds);
+      op.setUpperBounds(newUpperBounds);
+      op.setSteps(newSteps);
+      marker.replaceLinalgTransformationFilter(rewriter, op);
+    });
+    return success();
+  }
+
+private:
+  LinalgLoopDistributionOptions options;
+  LinalgTransformationFilter marker;
+};
+
+} // namespace
+
+void mlir::linalg::populateLinalgDistributeTiledLoopPattern(
+    RewritePatternSet &patterns, const LinalgLoopDistributionOptions &opts,
+    const LinalgTransformationFilter &marker) {
+  patterns.add<DistributeTiledLoopPattern>(patterns.getContext(), opts, marker);
+}
diff --git a/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir b/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/distribute-tiled-loop.mlir
@@ -0,0 +1,39 @@
+// RUN: mlir-opt -test-linalg-distribution %s | FileCheck %s
+
+func private @foo(%A: tensor<64x64xf32>,
+                  %B: tensor<64x64xf32>) -> tensor<64x64xf32>
+
+func @distribute_for_gpu(%A: tensor<64x64xf32>,
+                         %B: tensor<64x64xf32>) -> tensor<64x64xf32> {
+  %c0 = constant 0 : index
+  %c16 = constant 16 : index
+  %c64 = constant 64 : index
+  %c24 = constant 24 : index
+  %0 = linalg.tiled_loop (%i, %j) = (%c0, %c0) to (%c64, %c64) step (%c24, %c16)
+      ins (%A_ = %A: tensor<64x64xf32>) outs (%B_ = %B:tensor<64x64xf32>)
+      distribution ["block_x", "block_y"] {
+    %0 = call @foo(%A_, %B_)
+      : (tensor<64x64xf32>, tensor<64x64xf32>) -> tensor<64x64xf32>
+    linalg.yield %0 : tensor<64x64xf32>
+  }
+  return %0 : tensor<64x64xf32>
+}
+
+// CHECK-DAG: #[[$MAP0:.+]] = affine_map<()[s0] -> (s0 * 24)>
+// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 * 16)>
+
+// CHECK-LABEL: func @distribute_for_gpu
+// CHECK: %[[C64:.*]] = constant 64 : index
+
+// CHECK-DAG: %[[GPU_BLOCK_X:.*]] = "gpu.block_id"() {dimension = "x"}
+// CHECK-DAG: %[[GPU_GRID_DIM_X:.*]] = "gpu.grid_dim"() {dimension = "x"}
+// CHECK-DAG: %[[LB_I:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[GPU_BLOCK_X]]]
+// CHECK-DAG: %[[STEP_I:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[GPU_GRID_DIM_X]]]
+
+// CHECK-DAG: %[[GPU_BLOCK_Y:.*]] = "gpu.block_id"() {dimension = "y"}
+// CHECK-DAG: %[[GPU_GRID_DIM_Y:.*]] = "gpu.grid_dim"() {dimension = "y"}
+// CHECK-DAG: %[[LB_J:.*]] = affine.apply #[[$MAP1]](){{\[}}%[[GPU_BLOCK_Y]]]
+// CHECK-DAG: %[[STEP_J:.*]] = affine.apply #[[$MAP1]](){{\[}}%[[GPU_GRID_DIM_Y]]]
+
+// CHECK: linalg.tiled_loop (%[[I:.*]], %[[J:.*]]) = (%[[LB_I]], %[[LB_J]])
+// CHECK-SAME: to (%[[C64]], %[[C64]]) step (%[[STEP_I]], %[[STEP_J]])
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgDistribution.cpp
@@ -0,0 +1,88 @@
+//===- TestLinalgDistribution.cpp - Test Linalg distribution --------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a test pass that exercises the Linalg `tiled_loop`
+// distribution pattern, mapping "block_x"/"block_y" to GPU blocks.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+/// Returns the processor info for distributing over GPU blocks along `dim`:
+/// the block ID is the processor ID and the grid dimension is the number of
+/// processors.
+template <char dim>
+static linalg::ProcInfo getGpuBlockInfo(OpBuilder &b, Location loc) {
+  std::string d(1, dim);
+  StringAttr attr = b.getStringAttr(d);
+
+  Type indexType = b.getIndexType();
+  ProcInfo procInfo = {b.create<gpu::BlockIdOp>(loc, indexType, attr),
+                       b.create<gpu::GridDimOp>(loc, indexType, attr)};
+  return procInfo;
+}
+
+/// Maps the "block_x" and "block_y" distribution types to GPU blocks along
+/// the x and y dimensions respectively.
+static LinalgLoopDistributionOptions getDistributionOptions() {
+  LinalgLoopDistributionOptions opts;
+  opts.procInfoMap.insert(std::make_pair("block_x", getGpuBlockInfo<'x'>));
+  opts.procInfoMap.insert(std::make_pair("block_y", getGpuBlockInfo<'y'>));
+  return opts;
+}
+
+namespace {
+/// Test pass that applies the tiled_loop distribution pattern to every
+/// `linalg.tiled_loop` that is not nested inside another `linalg.tiled_loop`.
+struct TestLinalgDistribution
+    : public PassWrapper<TestLinalgDistribution, FunctionPass> {
+  TestLinalgDistribution() = default;
+  TestLinalgDistribution(const TestLinalgDistribution &pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<gpu::GPUDialect>();
+  }
+
+  void runOnFunction() override;
+};
+} // namespace
+
+void TestLinalgDistribution::runOnFunction() {
+  auto funcOp = getFunction();
+  OwningRewritePatternList distributeTiledLoopsPatterns(&getContext());
+  populateLinalgDistributeTiledLoopPattern(
+      distributeTiledLoopsPatterns, getDistributionOptions(),
+      LinalgTransformationFilter(
+          ArrayRef<Identifier>{},
+          {Identifier::get("distributed", funcOp.getContext())})
+          .addFilter([](Operation *op) {
+            return success(!op->getParentOfType<linalg::TiledLoopOp>());
+          }));
+  (void)applyPatternsAndFoldGreedily(funcOp,
+                                     std::move(distributeTiledLoopsPatterns));
+  // Ensure we drop the marker in the end. The distribution pattern sets it on
+  // `linalg.tiled_loop` ops, which are not `LinalgOp`s, so walk those.
+  funcOp.walk([](linalg::TiledLoopOp op) {
+    op->removeAttr(LinalgTransforms::kLinalgTransformMarker);
+  });
+}
+
+namespace mlir {
+namespace test {
+void registerTestLinalgDistribution() {
+  PassRegistration<TestLinalgDistribution> testTestLinalgDistributionPass(
+      "test-linalg-distribution", "Test Linalg distribution.");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -77,6 +77,7 @@
 void registerTestIRVisitorsPass();
 void registerTestInterfaces();
 void registerTestLinalgCodegenStrategy();
+void registerTestLinalgDistribution();
 void registerTestLinalgElementwiseFusion();
 void registerTestPushExpandingReshape();
 void registerTestLinalgFusionTransforms();
@@ -156,6 +157,7 @@
   test::registerTestIRVisitorsPass();
   test::registerTestInterfaces();
   test::registerTestLinalgCodegenStrategy();
+  test::registerTestLinalgDistribution();
   test::registerTestLinalgElementwiseFusion();
   test::registerTestPushExpandingReshape();
   test::registerTestLinalgFusionTransforms();