diff --git a/mlir/include/mlir/Dialect/LoopOps/Transforms.h b/mlir/include/mlir/Dialect/LoopOps/Transforms.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LoopOps/Transforms.h
@@ -0,0 +1,48 @@
+//===- Transforms.h - Loop Transformations ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines transformations on loop operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LOOPOPS_TRANSFORMS_H_
+#define MLIR_DIALECT_LOOPOPS_TRANSFORMS_H_
+
+#include "llvm/ADT/ArrayRef.h"
+
+namespace mlir {
+
+class Region;
+
+namespace loop {
+
+class ParallelOp;
+
+/// Fuses adjacent top-level loop.parallel ops in each block of `region` that
+/// have identical bounds and step into a single loop.parallel op. Uses a naive
+/// aliasing and dependency analysis.
+void naivelyFuseParallelOps(Region &region);
+
+/// Tile a parallel loop of the form
+///   loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+///                                            step (%arg4, %arg5)
+///
+/// into
+///   loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+///                                            step (%arg4*tileSize[0],
+///                                                  %arg5*tileSize[1])
+///     loop.parallel (%j0, %j1) = (0, 0) to (min(tileSize[0], %arg2-%i0),
+///                                           min(tileSize[1], %arg3-%i1))
+///                                        step (%arg4, %arg5)
+/// The old loop is replaced with the new one.
+void tileParallelLoop(ParallelOp op, llvm::ArrayRef<int64_t> tileSizes);
+
+} // namespace loop
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LOOPOPS_TRANSFORMS_H_
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
--- a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
@@ -12,6 +12,7 @@
 
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/Dialect/LoopOps/Passes.h"
+#include "mlir/Dialect/LoopOps/Transforms.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
@@ -20,7 +21,7 @@
 #include "mlir/Transforms/Passes.h"
 
 using namespace mlir;
-using loop::ParallelOp;
+using namespace mlir::loop;
 
 /// Verify there are no nested ParallelOps.
 static bool hasNestedParallelOp(ParallelOp ploop) {
@@ -135,32 +136,30 @@
   firstPloop.erase();
 }
 
-static void naivelyFuseParallelOps(Operation *op) {
-  OpBuilder b(op);
+void mlir::loop::naivelyFuseParallelOps(Region &region) {
+  OpBuilder b(region);
   // Consider every single block and attempt to fuse adjacent loops.
-  for (auto &region : op->getRegions()) {
-    for (auto &block : region.getBlocks()) {
-      SmallVector<SmallVector<ParallelOp, 8>, 1> ploop_chains{{}};
-      // Not using `walk()` to traverse only top-level parallel loops and also
-      // make sure that there are no side-effecting ops between the parallel
-      // loops.
-      bool noSideEffects = true;
-      for (auto &op : block.getOperations()) {
-        if (auto ploop = dyn_cast<ParallelOp>(op)) {
-          if (noSideEffects) {
-            ploop_chains.back().push_back(ploop);
-          } else {
-            ploop_chains.push_back({ploop});
-            noSideEffects = true;
-          }
-          continue;
+  for (auto &block : region.getBlocks()) {
+    SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
+    // Not using `walk()` to traverse only top-level parallel loops and also
+    // make sure that there are no side-effecting ops between the parallel
+    // loops.
+    bool noSideEffects = true;
+    for (auto &op : block.getOperations()) {
+      if (auto ploop = dyn_cast<ParallelOp>(op)) {
+        if (noSideEffects) {
+          ploopChains.back().push_back(ploop);
+        } else {
+          ploopChains.push_back({ploop});
+          noSideEffects = true;
         }
-        noSideEffects &= op.hasNoSideEffect();
-      }
-      for (ArrayRef<ParallelOp> ploops : ploop_chains) {
-        for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
-          fuseIfLegal(ploops[i], ploops[i + 1], b);
+        continue;
       }
+      noSideEffects &= op.hasNoSideEffect();
+    }
+    for (ArrayRef<ParallelOp> ploops : ploopChains) {
+      for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
+        fuseIfLegal(ploops[i], ploops[i + 1], b);
     }
   }
 }
@@ -168,7 +167,10 @@
 namespace {
 
 struct ParallelLoopFusion : public OperationPass<ParallelLoopFusion> {
-  void runOnOperation() override { naivelyFuseParallelOps(getOperation()); }
+  void runOnOperation() override {
+    for (Region &region : getOperation()->getRegions())
+      naivelyFuseParallelOps(region);
+  }
 };
 } // namespace
 
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp
--- a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopTiling.cpp
@@ -13,13 +13,14 @@
 #include "mlir/Dialect/AffineOps/AffineOps.h"
 #include "mlir/Dialect/LoopOps/LoopOps.h"
 #include "mlir/Dialect/LoopOps/Passes.h"
+#include "mlir/Dialect/LoopOps/Transforms.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/Support/CommandLine.h"
 
 using namespace mlir;
-using loop::ParallelOp;
+using namespace mlir::loop;
 
 /// Tile a parallel loop of the form
 ///   loop.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
@@ -33,7 +34,7 @@
 ///                                           min(tileSize[1], %arg3-%j1))
 ///                                        step (%arg4, %arg5)
 /// The old loop is replaced with the new one.
-static void tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
+void mlir::loop::tileParallelLoop(ParallelOp op, ArrayRef<int64_t> tileSizes) {
   OpBuilder b(op);
   auto zero = b.create<ConstantIndexOp>(op.getLoc(), 0);
   SmallVector<Value, 2> tileSizeConstants;