diff --git a/mlir/include/mlir/Dialect/LoopOps/Passes.h b/mlir/include/mlir/Dialect/LoopOps/Passes.h
--- a/mlir/include/mlir/Dialect/LoopOps/Passes.h
+++ b/mlir/include/mlir/Dialect/LoopOps/Passes.h
@@ -22,6 +22,9 @@
 /// Creates a loop fusion pass which fuses parallel loops.
 std::unique_ptr<Pass> createParallelLoopFusionPass();
 
+/// Creates a pass which tiles innermost parallel loops.
+std::unique_ptr<Pass> createParallelLoopTilingPass(int64_t tileSize = 4);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_LOOPOPS_PASSES_H_
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -109,6 +109,7 @@
   // LoopOps
   createParallelLoopFusionPass();
+  createParallelLoopTilingPass();
 
   // QuantOps
   quant::createConvertSimulatedQuantPass();
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_library(MLIRLoopOpsTransforms
   ParallelLoopFusion.cpp
+  ParallelLoopTiling.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LoopOps
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp
@@ -0,0 +1,120 @@
+//===- ParallelLoopTiling.cpp - Tiles loop.parallel ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop tiling on parallel loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/LoopOps/Passes.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace mlir;
+using loop::ParallelOp;
+
+/// Tile a parallel loop of the form
+///   loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+///                                             step (%arg4, %arg5)
+///
+/// into
+///   loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+///                                             step (%arg4*tileSize,
+///                                                   %arg5*tileSize)
+///     loop.parallel (%j0, %j1) = (0, 0) to (min(tileSize, %arg2-%arg0),
+///                                           min(tileSize, %arg3-%arg1))
+///                                        step (%arg4, %arg5)
+///
+/// The old loop is replaced with the new one.
+static void tileParallelLoop(ParallelOp op, int64_t tileSize) {
+  OpBuilder b(op);
+  auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
+  auto tileSizeConstant = b.create<ConstantIndexOp>(op.getLoc(), tileSize);
+
+  // Create the outer loop with adjusted steps.
+  SmallVector<Value, 2> newSteps;
+  newSteps.reserve(op.step().size());
+  for (Value step : op.step()) {
+    newSteps.push_back(b.create<MulIOp>(op.getLoc(), step, tileSizeConstant));
+  }
+  auto outerLoop = b.create<ParallelOp>(op.getLoc(), op.lowerBound(),
+                                        op.upperBound(), newSteps);
+  b.setInsertionPointToStart(outerLoop.getBody());
+
+  // Compute min(size, dim - offset) to avoid out-of-bounds accesses.
+  // FIXME: Instead of using min, we want to replicate the tail. This would
+  // give the inner loop constant bounds for easy vectorization.
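+  // For example, with the default tileSize of 4, a dimension with bounds
+  // (%arg0, %arg2) gets the inner upper bound min(4, %arg2 - %arg0): the
+  // two-result map built below, (d0, d1, d2) -> (d0, d1 - d2), is applied
+  // by affine.min to (tileSizeConstant, upperBound, lowerBound).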
+  auto minMap = AffineMap::get(
+      /*dimCount=*/3, /*symbolCount=*/0,
+      {getAffineDimExpr(/*position=*/0, b.getContext()),
+       getAffineDimExpr(/*position=*/1, b.getContext()) -
+           getAffineDimExpr(/*position=*/2, b.getContext())});
+
+  // Create the inner loop with adjusted bounds.
+  SmallVector<Value, 2> newBounds;
+  newBounds.reserve(op.upperBound().size());
+  for (auto bounds : llvm::zip(op.upperBound(), op.lowerBound())) {
+    newBounds.push_back(
+        b.create<AffineMinOp>(op.getLoc(), b.getIndexType(), minMap,
+                              ValueRange{tileSizeConstant, std::get<0>(bounds),
+                                         std::get<1>(bounds)}));
+  }
+  auto innerLoop = b.create<ParallelOp>(
+      op.getLoc(), SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
+      op.step());
+
+  // Steal the body of the old parallel loop and erase it.
+  innerLoop.region().takeBody(op.region());
+  op.erase();
+}
+
+/// Get a list of most nested parallel loops. Assumes that ParallelOps are
+/// only directly nested.
+static bool getInnermostNestedLoops(Block *block,
+                                    SmallVectorImpl<ParallelOp> &loops) {
+  bool hasInnerLoop = false;
+  for (auto parallelOp : block->getOps<ParallelOp>()) {
+    hasInnerLoop = true;
+    if (!getInnermostNestedLoops(parallelOp.getBody(), loops)) {
+      loops.push_back(parallelOp);
+    }
+  }
+  return hasInnerLoop;
+}
+
+namespace {
+struct ParallelLoopTiling : public FunctionPass<ParallelLoopTiling> {
+  ParallelLoopTiling() = default;
+  ParallelLoopTiling(const ParallelLoopTiling &) {} // tileSize is non-copyable.
+  explicit ParallelLoopTiling(int64_t tileSize) { this->tileSize = tileSize; }
+
+  void runOnFunction() override {
+    SmallVector<ParallelOp, 2> mostNestedParallelOps;
+    for (Block &block : getFunction()) {
+      getInnermostNestedLoops(&block, mostNestedParallelOps);
+    }
+    for (ParallelOp pLoop : mostNestedParallelOps) {
+      tileParallelLoop(pLoop, tileSize);
+    }
+  }
+
+  Option<int64_t> tileSize{
+      *this, "parallel-loop-tile-size",
+      llvm::cl::desc("factor to tile innermost parallel loops by"),
+      llvm::cl::init(4)};
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createParallelLoopTilingPass(int64_t tileSize) {
+  return std::make_unique<ParallelLoopTiling>(tileSize);
+}
+
+static PassRegistration<ParallelLoopTiling>
+    pass("parallel-loop-tiling", "Tile parallel loops.");
diff --git a/mlir/test/Dialect/Loops/parallel-loop-tiling.mlir b/mlir/test/Dialect/Loops/parallel-loop-tiling.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Loops/parallel-loop-tiling.mlir
@@ -0,0 +1,85 @@
+// RUN: mlir-opt %s -pass-pipeline='func(parallel-loop-tiling)' -split-input-file | FileCheck %s --dump-input-on-failure
+
+func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                    %arg3 : index, %arg4 : index, %arg5 : index,
+                    %A: memref<?x?xf32>, %B: memref<?x?xf32>,
+                    %C: memref<?x?xf32>, %result: memref<?x?xf32>) {
+  loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3) step (%arg4, %arg5) {
+    %B_elem = load %B[%i0, %i1] : memref<?x?xf32>
+    %C_elem = load %C[%i0, %i1] : memref<?x?xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %result[%i0, %i1] : memref<?x?xf32>
+  }
+  return
+}
+
+// CHECK: #map0 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+
+// CHECK-LABEL: func @parallel_loop(
+// CHECK-SAME: [[VAL_0:%.*]]: index, [[VAL_1:%.*]]: index, [[VAL_2:%.*]]: index, [[VAL_3:%.*]]: index, [[VAL_4:%.*]]: index, [[VAL_5:%.*]]: index, [[VAL_6:%.*]]: memref<?x?xf32>, [[VAL_7:%.*]]: memref<?x?xf32>, [[VAL_8:%.*]]: memref<?x?xf32>, [[VAL_9:%.*]]: memref<?x?xf32>) {
+// CHECK: [[VAL_10:%.*]] = constant 0 : index
+// CHECK: [[VAL_11:%.*]] = constant 4 : index
+// CHECK: [[VAL_12:%.*]] = muli [[VAL_4]], [[VAL_11]] : index
+// CHECK: [[VAL_13:%.*]] = muli [[VAL_5]], [[VAL_11]] : index
+// CHECK: loop.parallel ([[VAL_14:%.*]], [[VAL_15:%.*]]) = ([[VAL_0]], [[VAL_1]]) to ([[VAL_2]], [[VAL_3]]) step ([[VAL_12]], [[VAL_13]]) {
+// CHECK: [[VAL_16:%.*]] = affine.min #map0([[VAL_11]], [[VAL_2]], [[VAL_0]])
+// CHECK: [[VAL_17:%.*]] = affine.min #map0([[VAL_11]], [[VAL_3]], [[VAL_1]])
+// CHECK: loop.parallel ([[VAL_18:%.*]], [[VAL_19:%.*]]) = ([[VAL_10]], [[VAL_10]]) to ([[VAL_16]], [[VAL_17]]) step ([[VAL_4]], [[VAL_5]]) {
+// CHECK: [[VAL_20:%.*]] = load [[VAL_7]]{{\[}}[[VAL_18]], [[VAL_19]]] : memref<?x?xf32>
+// CHECK: [[VAL_21:%.*]] = load [[VAL_8]]{{\[}}[[VAL_18]], [[VAL_19]]] : memref<?x?xf32>
+// CHECK: [[VAL_22:%.*]] = addf [[VAL_20]], [[VAL_21]] : f32
+// CHECK: store [[VAL_22]], [[VAL_9]]{{\[}}[[VAL_18]], [[VAL_19]]] : memref<?x?xf32>
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+
+func @tile_nested_innermost() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    }
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+  }
+  return
+}
+
+// CHECK-LABEL: func @tile_nested_innermost() {
+// CHECK: [[VAL_23:%.*]] = constant 2 : index
+// CHECK: [[VAL_24:%.*]] = constant 0 : index
+// CHECK: [[VAL_25:%.*]] = constant 1 : index
+// CHECK: loop.parallel ([[VAL_26:%.*]], [[VAL_27:%.*]]) = ([[VAL_24]], [[VAL_24]]) to ([[VAL_23]], [[VAL_23]]) step ([[VAL_25]], [[VAL_25]]) {
+// CHECK: [[VAL_28:%.*]] = constant 0 : index
+// CHECK: [[VAL_29:%.*]] = constant 4 : index
+// CHECK: [[VAL_30:%.*]] = muli [[VAL_25]], [[VAL_29]] : index
+// CHECK: [[VAL_31:%.*]] = muli [[VAL_25]], [[VAL_29]] : index
+// CHECK: loop.parallel ([[VAL_32:%.*]], [[VAL_33:%.*]]) = ([[VAL_24]], [[VAL_24]]) to ([[VAL_23]], [[VAL_23]]) step ([[VAL_30]], [[VAL_31]]) {
+// CHECK: [[VAL_34:%.*]] = affine.min #map0([[VAL_29]], [[VAL_23]], [[VAL_24]])
+// CHECK: [[VAL_35:%.*]] = affine.min #map0([[VAL_29]], [[VAL_23]], [[VAL_24]])
+// CHECK: loop.parallel ([[VAL_36:%.*]], [[VAL_37:%.*]]) = ([[VAL_28]], [[VAL_28]]) to ([[VAL_34]], [[VAL_35]]) step ([[VAL_25]], [[VAL_25]]) {
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK: [[VAL_38:%.*]] = constant 0 : index
+// CHECK: [[VAL_39:%.*]] = constant 4 : index
+// CHECK: [[VAL_40:%.*]] = muli [[VAL_25]], [[VAL_39]] : index
+// CHECK: [[VAL_41:%.*]] = muli [[VAL_25]], [[VAL_39]] : index
+// CHECK: loop.parallel ([[VAL_42:%.*]], [[VAL_43:%.*]]) = ([[VAL_24]], [[VAL_24]]) to ([[VAL_23]], [[VAL_23]]) step ([[VAL_40]], [[VAL_41]]) {
+// CHECK: [[VAL_44:%.*]] = affine.min #map0([[VAL_39]], [[VAL_23]], [[VAL_24]])
+// CHECK: [[VAL_45:%.*]] = affine.min #map0([[VAL_39]], [[VAL_23]], [[VAL_24]])
+// CHECK: loop.parallel ([[VAL_46:%.*]], [[VAL_47:%.*]]) = ([[VAL_38]], [[VAL_38]]) to ([[VAL_44]], [[VAL_45]]) step ([[VAL_25]], [[VAL_25]]) {
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK: return
+// CHECK: }
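
Usage sketch (not part of the patch): the test's RUN line shows the textual
pipeline form, and the new createParallelLoopTilingPass entry point lets C++
clients pick a tile size. A minimal sketch, assuming `funcPM` is the
function-level nest of a mlir::PassManager built elsewhere; the tile size of 8
is an arbitrary illustration:

  #include "mlir/Dialect/LoopOps/Passes.h"
  #include "mlir/Pass/PassManager.h"

  // Tile innermost parallel loops by 8 instead of the default 4.
  void buildTilingPipeline(mlir::OpPassManager &funcPM) {
    funcPM.addPass(mlir::createParallelLoopTilingPass(/*tileSize=*/8));
  }

From the command line, the registered pass runs with its default tile size of 4:

  mlir-opt input.mlir -pass-pipeline='func(parallel-loop-tiling)'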