diff --git a/mlir/include/mlir/Dialect/LoopOps/Passes.h b/mlir/include/mlir/Dialect/LoopOps/Passes.h
--- a/mlir/include/mlir/Dialect/LoopOps/Passes.h
+++ b/mlir/include/mlir/Dialect/LoopOps/Passes.h
@@ -22,6 +22,10 @@
 /// Creates a loop fusion pass which fuses parallel loops.
 std::unique_ptr<Pass> createParallelLoopFusionPass();
 
+/// Creates a pass which tiles innermost parallel loops by `tileSize` in every
+/// dimension.
+std::unique_ptr<Pass> createParallelLoopTilingPass(int64_t tileSize = 4);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_LOOPOPS_PASSES_H_
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -109,6 +109,7 @@
 
   // LoopOps
   createParallelLoopFusionPass();
+  createParallelLoopTilingPass();
 
   // QuantOps
   quant::createConvertSimulatedQuantPass();
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_llvm_library(MLIRLoopOpsTransforms
   ParallelLoopFusion.cpp
+  ParallelLoopTiling.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LoopOps
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp
@@ -0,0 +1,156 @@
+//===- ParallelLoopTiling.cpp - Tiles loop.parallel -----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop tiling on parallel loops.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AffineOps/AffineOps.h"
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/LoopOps/Passes.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/RegionUtils.h"
+#include "llvm/Support/CommandLine.h"
+
+using namespace mlir;
+using loop::ParallelOp;
+
+static llvm::cl::opt<int64_t>
+    clTileSize("parallel-loop-tile-size",
+               llvm::cl::desc("factor to tile innermost parallel loops by"),
+               llvm::cl::init(4));
+
+/// Tiles a loop.parallel of the form
+///   loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+///                 step (%arg4, %arg5)
+///
+/// into
+///   loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+///                 step (%arg4 * tileSize, %arg5 * tileSize)
+///     loop.parallel (%j0, %j1) = (0, 0)
+///                   to (min(%arg4 * tileSize, %arg2 - %i0),
+///                       min(%arg5 * tileSize, %arg3 - %i1))
+///                   step (%arg4, %arg5)
+///       %k0 = addi %j0, %i0
+///       %k1 = addi %j1, %i1
+///       <old body, using %k0 and %k1 in place of the old IVs>
+///
+/// The min keeps the last, possibly partial, tile within the original upper
+/// bound. The old loop is erased.
+static void tilePLoop(ParallelOp op, int64_t tileSize) {
+  OpBuilder b(op);
+  Location loc = op.getLoc();
+  auto zero = b.create<ConstantIndexOp>(loc, 0);
+  auto tileSizeConstant = b.create<ConstantIndexOp>(loc, tileSize);
+
+  // Create the outer loop. Each of its iterations covers one tile, i.e.
+  // `tileSize` iterations of the original loop in every dimension.
+  SmallVector<Value, 2> newSteps;
+  newSteps.reserve(op.step().size());
+  for (Value step : op.step())
+    newSteps.push_back(b.create<MulIOp>(loc, step, tileSizeConstant));
+  auto outerLoop =
+      b.create<ParallelOp>(loc, op.lowerBound(), op.upperBound(), newSteps);
+  b.setInsertionPointToStart(outerLoop.getBody());
+
+  // Compute min(tile extent, upperBound - outerIV) so that the last tile does
+  // not run past the original upper bound.
+  // FIXME: Instead of using min, we want to replicate the tail. This would give
+  // the inner loop constant bounds for easy vectorization.
+  auto minMap = AffineMap::get(
+      /*dimCount=*/3, /*symbolCount=*/0,
+      {getAffineDimExpr(/*position=*/0, b.getContext()),
+       getAffineDimExpr(/*position=*/1, b.getContext()) -
+           getAffineDimExpr(/*position=*/2, b.getContext())});
+
+  // Create the inner loop. It walks one tile, starting at offset 0, with the
+  // original steps.
+  SmallVector<Value, 2> newBounds;
+  newBounds.reserve(newSteps.size());
+  for (auto bounds : llvm::zip(newSteps, op.upperBound(),
+                               outerLoop.getInductionVars())) {
+    newBounds.push_back(b.create<AffineMinOp>(
+        loc, b.getIndexType(), minMap,
+        ValueRange{std::get<0>(bounds), std::get<1>(bounds),
+                   std::get<2>(bounds)}));
+  }
+  auto innerLoop = b.create<ParallelOp>(
+      loc, SmallVector<Value, 2>(newBounds.size(), zero), newBounds,
+      op.step());
+
+  // Steal the body of the old loop, leaving its terminator behind.
+  innerLoop.getBody()->getOperations().splice(
+      innerLoop.getBody()->getOperations().begin(),
+      op.getBody()->getOperations(), op.getBody()->begin(),
+      std::prev(op.getBody()->end()));
+
+  // The inner IVs are offsets within the tile. Rebuild the original index as
+  // innerIV + outerIV at the top of the inner body and use it in place of the
+  // old IVs.
+  b.setInsertionPointToStart(innerLoop.getBody());
+  for (auto ivs : llvm::zip(op.getInductionVars(), innerLoop.getInductionVars(),
+                            outerLoop.getInductionVars())) {
+    Value newIndex = b.create<AddIOp>(loc, std::get<1>(ivs), std::get<2>(ivs));
+    replaceAllUsesInRegionWith(std::get<0>(ivs), newIndex, innerLoop.region());
+  }
+
+  // Erase the old loop.
+  op.erase();
+}
+
+/// Appends to `loops` every loop.parallel nested in `block` that does not
+/// itself contain a loop.parallel. Returns true if `block` directly contains
+/// at least one loop.parallel. Assumes that ParallelOps are only directly
+/// nested.
+static bool getInnermostNestedLoops(Block *block,
+                                    SmallVectorImpl<ParallelOp> &loops) {
+  bool hasParallelOp = false;
+  for (auto &op : *block) {
+    if (auto parallelOp = dyn_cast<ParallelOp>(op)) {
+      hasParallelOp = true;
+      // Recurse into every sibling; collect a loop only if nothing is nested
+      // in it.
+      if (!getInnermostNestedLoops(parallelOp.getBody(), loops))
+        loops.push_back(parallelOp);
+    }
+  }
+  return hasParallelOp;
+}
+
+namespace {
+struct ParallelLoopTiling : public FunctionPass<ParallelLoopTiling> {
+  explicit ParallelLoopTiling(int64_t tileSize) : tileSize(tileSize) {}
+
+  void runOnFunction() override {
+    // A non-positive tile size would yield a zero or negative outer step.
+    if (tileSize <= 0) {
+      getFunction().emitError("parallel loop tile size must be positive");
+      return signalPassFailure();
+    }
+
+    // Collect all candidates first: tilePLoop erases the loops it rewrites.
+    SmallVector<ParallelOp, 2> innermostParallelOps;
+    for (Block &block : getFunction())
+      getInnermostNestedLoops(&block, innermostParallelOps);
+    for (ParallelOp pLoop : innermostParallelOps)
+      tilePLoop(pLoop, tileSize);
+  }
+
+  /// Number of original iterations per tile in every dimension.
+  int64_t tileSize;
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createParallelLoopTilingPass(int64_t tileSize) {
+  return std::make_unique<ParallelLoopTiling>(tileSize);
+}
+
+static PassRegistration<ParallelLoopTiling>
+    pass("parallel-loop-tiling", "Tile parallel loops.",
+         [] { return createParallelLoopTilingPass(clTileSize); });
diff --git a/mlir/test/Dialect/Loops/parallel-loop-tiling.mlir b/mlir/test/Dialect/Loops/parallel-loop-tiling.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Loops/parallel-loop-tiling.mlir
@@ -0,0 +1,97 @@
+// RUN: mlir-opt %s -pass-pipeline='func(parallel-loop-tiling)' -split-input-file | FileCheck %s --dump-input-on-failure
+
+func @parallel_loop(%arg0 : index, %arg1 : index, %arg2 : index,
+                    %arg3 : index, %arg4 : index, %arg5 : index,
+                    %A: memref<?x?xf32>, %B: memref<?x?xf32>,
+                    %C: memref<?x?xf32>, %result: memref<?x?xf32>) {
+  loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                                            step (%arg4, %arg5) {
+    %B_elem = load %B[%i0, %i1] : memref<?x?xf32>
+    %C_elem = load %C[%i0, %i1] : memref<?x?xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %result[%i0, %i1] : memref<?x?xf32>
+  }
+  return
+}
+
+// CHECK: #map0 = affine_map<(d0, d1, d2) -> (d0, d1 - d2)>
+
+// CHECK-LABEL: func @parallel_loop(
+// CHECK-SAME:    [[LB0:%[a-z0-9_]+]]: index, [[LB1:%[a-z0-9_]+]]: index,
+// CHECK-SAME:    [[UB0:%[a-z0-9_]+]]: index, [[UB1:%[a-z0-9_]+]]: index,
+// CHECK-SAME:    [[STEP0:%[a-z0-9_]+]]: index, [[STEP1:%[a-z0-9_]+]]: index,
+// CHECK-SAME:    [[A:%[a-z0-9_]+]]: memref<?x?xf32>, [[B:%[a-z0-9_]+]]: memref<?x?xf32>,
+// CHECK-SAME:    [[C:%[a-z0-9_]+]]: memref<?x?xf32>, [[RESULT:%[a-z0-9_]+]]: memref<?x?xf32>) {
+// CHECK:         [[C0:%[a-z0-9_]+]] = constant 0 : index
+// CHECK:         [[C4:%[a-z0-9_]+]] = constant 4 : index
+// CHECK:         [[NEW_STEP0:%[a-z0-9_]+]] = muli [[STEP0]], [[C4]] : index
+// CHECK:         [[NEW_STEP1:%[a-z0-9_]+]] = muli [[STEP1]], [[C4]] : index
+// CHECK:         loop.parallel ([[I0:%[a-z0-9_]+]], [[I1:%[a-z0-9_]+]]) = ([[LB0]], [[LB1]]) to ([[UB0]], [[UB1]]) step ([[NEW_STEP0]], [[NEW_STEP1]]) {
+// CHECK:           [[TILE_UB0:%[a-z0-9_]+]] = affine.min #map0([[NEW_STEP0]], [[UB0]], [[I0]])
+// CHECK:           [[TILE_UB1:%[a-z0-9_]+]] = affine.min #map0([[NEW_STEP1]], [[UB1]], [[I1]])
+// CHECK:           loop.parallel ([[J0:%[a-z0-9_]+]], [[J1:%[a-z0-9_]+]]) = ([[C0]], [[C0]]) to ([[TILE_UB0]], [[TILE_UB1]]) step ([[STEP0]], [[STEP1]]) {
+// CHECK:             [[K0:%[a-z0-9_]+]] = addi [[J0]], [[I0]] : index
+// CHECK:             [[K1:%[a-z0-9_]+]] = addi [[J1]], [[I1]] : index
+// CHECK:             [[B_ELEM:%[a-z0-9_]+]] = load [[B]]{{\[}}[[K0]], [[K1]]] : memref<?x?xf32>
+// CHECK:             [[C_ELEM:%[a-z0-9_]+]] = load [[C]]{{\[}}[[K0]], [[K1]]] : memref<?x?xf32>
+// CHECK:             [[SUM:%[a-z0-9_]+]] = addf [[B_ELEM]], [[C_ELEM]] : f32
+// CHECK:             store [[SUM]], [[RESULT]]{{\[}}[[K0]], [[K1]]] : memref<?x?xf32>
+// CHECK:             "loop.terminator"() : () -> ()
+// CHECK:           }
+// CHECK:           "loop.terminator"() : () -> ()
+// CHECK:         }
+// CHECK:         return
+// CHECK:       }
+
+func @tile_nested_innermost() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    }
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+  }
+  return
+}
+
+// Only innermost loops are tiled: the inner loop of the first nest and the
+// second, unnested loop. The outer loop of the first nest is left unchanged.
+// CHECK-LABEL: func @tile_nested_innermost() {
+// CHECK:         [[C2:%[a-z0-9_]+]] = constant 2 : index
+// CHECK:         [[C0:%[a-z0-9_]+]] = constant 0 : index
+// CHECK:         [[C1:%[a-z0-9_]+]] = constant 1 : index
+// CHECK:         loop.parallel ([[I:%[a-z0-9_]+]], [[J:%[a-z0-9_]+]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:           [[ZERO:%[a-z0-9_]+]] = constant 0 : index
+// CHECK:           [[C4:%[a-z0-9_]+]] = constant 4 : index
+// CHECK:           [[STEP0:%[a-z0-9_]+]] = muli [[C1]], [[C4]] : index
+// CHECK:           [[STEP1:%[a-z0-9_]+]] = muli [[C1]], [[C4]] : index
+// CHECK:           loop.parallel ([[K:%[a-z0-9_]+]], [[L:%[a-z0-9_]+]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[STEP0]], [[STEP1]]) {
+// CHECK:             [[UB0:%[a-z0-9_]+]] = affine.min #map0([[STEP0]], [[C2]], [[K]])
+// CHECK:             [[UB1:%[a-z0-9_]+]] = affine.min #map0([[STEP1]], [[C2]], [[L]])
+// CHECK:             loop.parallel ([[KK:%[a-z0-9_]+]], [[LL:%[a-z0-9_]+]]) = ([[ZERO]], [[ZERO]]) to ([[UB0]], [[UB1]]) step ([[C1]], [[C1]]) {
+// CHECK:               {{%[a-z0-9_]+}} = addi [[KK]], [[K]] : index
+// CHECK:               {{%[a-z0-9_]+}} = addi [[LL]], [[L]] : index
+// CHECK:               "loop.terminator"() : () -> ()
+// CHECK:             }
+// CHECK:             "loop.terminator"() : () -> ()
+// CHECK:           }
+// CHECK:           "loop.terminator"() : () -> ()
+// CHECK:         }
+// CHECK:         [[ZERO2:%[a-z0-9_]+]] = constant 0 : index
+// CHECK:         [[C4_2:%[a-z0-9_]+]] = constant 4 : index
+// CHECK:         [[STEP2:%[a-z0-9_]+]] = muli [[C1]], [[C4_2]] : index
+// CHECK:         [[STEP3:%[a-z0-9_]+]] = muli [[C1]], [[C4_2]] : index
+// CHECK:         loop.parallel ([[M:%[a-z0-9_]+]], [[N:%[a-z0-9_]+]]) = ([[C0]], [[C0]]) to ([[C2]], [[C2]]) step ([[STEP2]], [[STEP3]]) {
+// CHECK:           [[UB2:%[a-z0-9_]+]] = affine.min #map0([[STEP2]], [[C2]], [[M]])
+// CHECK:           [[UB3:%[a-z0-9_]+]] = affine.min #map0([[STEP3]], [[C2]], [[N]])
+// CHECK:           loop.parallel ([[MM:%[a-z0-9_]+]], [[NN:%[a-z0-9_]+]]) = ([[ZERO2]], [[ZERO2]]) to ([[UB2]], [[UB3]]) step ([[C1]], [[C1]]) {
+// CHECK:             {{%[a-z0-9_]+}} = addi [[MM]], [[M]] : index
+// CHECK:             {{%[a-z0-9_]+}} = addi [[NN]], [[N]] : index
+// CHECK:             "loop.terminator"() : () -> ()
+// CHECK:           }
+// CHECK:           "loop.terminator"() : () -> ()
+// CHECK:         }
+// CHECK:         return
+// CHECK:       }