diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -190,6 +190,7 @@
     iterator_range<Block::args_iterator> getInductionVars() {
       return {getBody()->args_begin(), getBody()->args_end()};
     }
+    unsigned getNumLoops() { return step().size(); }
   }];
 }
 
diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h
--- a/mlir/include/mlir/Transforms/Passes.h
+++ b/mlir/include/mlir/Transforms/Passes.h
@@ -74,6 +74,9 @@
                      uint64_t localBufSizeThreshold = 0,
                      bool maximalFusion = false);
 
+/// Creates a loop fusion pass which fuses parallel loops.
+std::unique_ptr<OpPassBase<FuncOp>> createParallelLoopFusionPass();
+
 /// Creates a loop invariant code motion pass that hoists loop invariant
 /// instructions out of the loop.
 std::unique_ptr<Pass> createLoopInvariantCodeMotionPass();
diff --git a/mlir/lib/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Transforms/ParallelLoopFusion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Transforms/ParallelLoopFusion.cpp
@@ -0,0 +1,191 @@
+//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop fusion on parallel loops. Two `loop.parallel`
+// operations that directly follow each other in the same block are fused when
+// they have identical iteration spaces and every buffer written by one of them
+// is accessed by the other only at the element written in the same iteration.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+namespace mlir {
+namespace {
+
+using loop::ParallelOp;
+
+/// Returns true if the body of `ploop` contains another ParallelOp at any
+/// nesting depth.
+bool hasNestedParallelOp(ParallelOp ploop) {
+  auto walkResult = ploop.getBody()->walk(
+      [](ParallelOp) -> WalkResult { return WalkResult::interrupt(); });
+  return walkResult.wasInterrupted();
+}
+
+/// Returns true if both loops have the same number of induction variables and
+/// use the very same SSA values as lower bounds, upper bounds and steps.
+bool equalIterationSpaces(ParallelOp ploop1, ParallelOp ploop2) {
+  if (ploop1.getNumLoops() != ploop2.getNumLoops())
+    return false;
+
+  auto matchOperands = [&](const OperandRange &lhs,
+                           const OperandRange &rhs) -> bool {
+    for (auto item : llvm::zip(lhs, rhs)) {
+      Value p1, p2;
+      std::tie(p1, p2) = item;
+      // TODO(pifon): Extend this to support aliases and equal constants.
+      if (p1 != p2)
+        return false;
+    }
+    return true;
+  };
+  return matchOperands(ploop1.lowerBound(), ploop2.lowerBound()) &&
+         matchOperands(ploop1.upperBound(), ploop2.upperBound()) &&
+         matchOperands(ploop1.step(), ploop2.step());
+}
+
+/// Returns true if every load and store in `second` that accesses a buffer
+/// stored to by `first` uses exactly the indices of `first`'s single store to
+/// that buffer, once `first`'s induction variables are remapped to `second`'s
+/// through `firstToSecond`. Buffers are identified by SSA value only, so
+/// aliasing memrefs are not detected.
+bool accessesMatchSingleStore(ParallelOp first, ParallelOp second,
+                              const BlockAndValueMapping &firstToSecond) {
+  DenseMap<Value, SmallVector<OperandRange, 1>> bufferStores;
+  first.getBody()->walk([&](StoreOp store) {
+    bufferStores[store.getMemRef()].push_back(store.indices());
+  });
+
+  auto checkAccess = [&](Value memref, OperandRange indices) -> WalkResult {
+    auto write = bufferStores.find(memref);
+    if (write == bufferStores.end())
+      return WalkResult::advance();
+
+    // Allow only single write access per buffer.
+    if (write->second.size() != 1)
+      return WalkResult::interrupt();
+
+    // The access must address the element stored to in the same iteration.
+    OperandRange storeIndices = write->second.front();
+    if (storeIndices.size() != indices.size())
+      return WalkResult::interrupt();
+    for (auto item : llvm::zip(storeIndices, indices)) {
+      Value storeIndex, accessIndex;
+      std::tie(storeIndex, accessIndex) = item;
+      if (firstToSecond.lookupOrDefault(storeIndex) != accessIndex)
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  };
+
+  Block *body = second.getBody();
+  WalkResult loadsResult = body->walk([&](LoadOp load) {
+    return checkAccess(load.getMemRef(), load.indices());
+  });
+  if (loadsResult.wasInterrupted())
+    return false;
+  WalkResult storesResult = body->walk([&](StoreOp store) {
+    return checkAccess(store.getMemRef(), store.indices());
+  });
+  return !storesResult.wasInterrupted();
+}
+
+/// Analyzes dependencies in the most primitive way: every buffer written by
+/// one loop may be accessed by the other loop only at the element written in
+/// the same iteration. Checking both directions rules out read-after-write,
+/// write-after-read and write-after-write conflicts among loads and stores.
+bool verifyDependencies(ParallelOp ploop1, ParallelOp ploop2,
+                        const BlockAndValueMapping &map) {
+  if (!accessesMatchSingleStore(ploop1, ploop2, map))
+    return false;
+
+  BlockAndValueMapping reverseMap;
+  reverseMap.map(ploop2.getBody()->getArguments(),
+                 ploop1.getBody()->getArguments());
+  return accessesMatchSingleStore(ploop2, ploop1, reverseMap);
+}
+
+/// Returns true if `ploop1` may be fused into `ploop2`, where `map` maps the
+/// induction variables of `ploop1` to those of `ploop2`.
+bool isFusionLegal(ParallelOp ploop1, ParallelOp ploop2,
+                   const BlockAndValueMapping &map) {
+  // `ploop1` is erased after its body is cloned, so it must not produce
+  // results (e.g. reductions) that would be left without a definition.
+  return ploop1.getOperation()->getNumResults() == 0 &&
+         !hasNestedParallelOp(ploop2) && !hasNestedParallelOp(ploop1) &&
+         equalIterationSpaces(ploop1, ploop2) &&
+         verifyDependencies(ploop1, ploop2, map);
+}
+
+/// If legal, prepends the operations of `ploop1`'s body (except its
+/// terminator) to `ploop2`'s body, substituting `ploop2`'s induction variables
+/// for `ploop1`'s, and erases `ploop1`. Does nothing otherwise.
+void fuseIfLegal(ParallelOp ploop1, ParallelOp ploop2, OpBuilder &b) {
+  BlockAndValueMapping map;
+  map.map(ploop1.getBody()->getArguments(), ploop2.getBody()->getArguments());
+
+  if (!isFusionLegal(ploop1, ploop2, map))
+    return;
+
+  b.setInsertionPointToStart(ploop2.getBody());
+  for (auto &op : ploop1.getBody()->without_terminator())
+    b.clone(op, map);
+  ploop1.erase();
+}
+
+/// Attempts to fuse top-level parallel loops in each block of `f`'s body. Only
+/// loops that directly follow each other are considered: any operation between
+/// them (e.g. a load, store or dealloc) could depend on memory accessed by the
+/// first loop, and fusion would move that loop's body past it.
+void naivelyFuseParallelOps(FuncOp f) {
+  OpBuilder b(f);
+
+  for (auto &block : f.getBody().getBlocks()) {
+    // Split the top-level parallel loops into chains of directly adjacent
+    // loops. Not using `walk()` to traverse only top-level parallel loops.
+    SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains;
+    ploopChains.emplace_back();
+    for (auto &op : block.without_terminator()) {
+      auto ploop = dyn_cast<ParallelOp>(op);
+      if (!ploop) {
+        if (!ploopChains.back().empty())
+          ploopChains.emplace_back();
+        continue;
+      }
+      ploopChains.back().push_back(ploop);
+    }
+    // Iteratively fuse adjacent loops of each chain. After a successful
+    // fusion `ploops[i]` is erased and `ploops[i + 1]` holds the fused body.
+    for (auto &ploops : ploopChains)
+      for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
+        fuseIfLegal(ploops[i], ploops[i + 1], b);
+  }
+}
+
+/// Function pass that fuses adjacent parallel loops.
+struct ParallelLoopFusion : public FunctionPass<ParallelLoopFusion> {
+  void runOnFunction() override { naivelyFuseParallelOps(getFunction()); }
+};
+
+} // namespace
+
+std::unique_ptr<OpPassBase<FuncOp>> createParallelLoopFusionPass() {
+  return std::make_unique<ParallelLoopFusion>();
+}
+
+} // namespace mlir
+
+static mlir::PassRegistration<mlir::ParallelLoopFusion>
+    pass("parallel-loop-fusion", "Fuse parallel loop nests");
diff --git a/mlir/test/Transforms/parallel-loop-fusion.mlir b/mlir/test/Transforms/parallel-loop-fusion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Transforms/parallel-loop-fusion.mlir
@@ -0,0 +1,289 @@
+// RUN: mlir-opt %s -parallel-loop-fusion -split-input-file | FileCheck %s --dump-input-on-failure
+
+func @fuse_empty_loops() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @fuse_empty_loops
+// CHECK:        [[C2:%.*]] = constant 2 : index
+// CHECK:        [[C0:%.*]] = constant 0 : index
+// CHECK:        [[C1:%.*]] = constant 1 : index
+// CHECK:        loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:       to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:          "loop.terminator"() : () -> ()
+// CHECK:        }
+// CHECK-NOT:    loop.parallel
+
+// -----
+
+func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
+               %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %sum = alloc() {temp = true} : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = load %B[%i, %j] : memref<2x2xf32>
+    %C_elem = load %C[%i, %j] : memref<2x2xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = mulf %sum_elem, %A_elem : f32
+    store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @fuse_two
+// CHECK-SAME:   ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
+// CHECK-SAME:    [[RESULT:%.*]]: {{.*}}) {
+// CHECK:      [[C2:%.*]] = constant 2 : index
+// CHECK:      [[C0:%.*]] = constant 0 : index
+// CHECK:      [[C1:%.*]] = constant 1 : index
+// CHECK:      [[SUM:%.*]] = alloc() {temp = true}
+// CHECK:      loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[B_ELEM:%.*]] = load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[C_ELEM:%.*]] = load [[C]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM:%.*]] = addf [[B_ELEM]], [[C_ELEM]]
+// CHECK:        store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[SUM_ELEM_:%.*]] = load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK:        store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK:        "loop.terminator"() : () -> ()
+// CHECK:      }
+// CHECK:      dealloc [[SUM]]
+
+// -----
+
+func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
+                 %result: memref<100x10xf32>) {
+  %c100 = constant 100 : index
+  %c10 = constant 10 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %broadcast_rhs = alloc() {temp = true} : memref<100x10xf32>
+  %diff = alloc() {temp = true} : memref<100x10xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
+    %rhs_elem = load %rhs[%i] : memref<100xf32>
+    store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
+    %lhs_elem = load %lhs[%i, %j] : memref<100x10xf32>
+    %broadcast_rhs_elem = load %broadcast_rhs[%i, %j] : memref<100x10xf32>
+    %diff_elem = subf %lhs_elem, %broadcast_rhs_elem : f32
+    store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
+    %diff_elem = load %diff[%i, %j] : memref<100x10xf32>
+    %exp_elem = exp %diff_elem : f32
+    store %exp_elem, %result[%i, %j] : memref<100x10xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %broadcast_rhs : memref<100x10xf32>
+  dealloc %diff : memref<100x10xf32>
+  return
+}
+// CHECK-LABEL: func @fuse_three
+// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>,
+// CHECK-SAME:  [[RESULT:%.*]]: memref<100x10xf32>) {
+// CHECK:      [[C100:%.*]] = constant 100 : index
+// CHECK:      [[C10:%.*]] = constant 10 : index
+// CHECK:      [[C0:%.*]] = constant 0 : index
+// CHECK:      [[C1:%.*]] = constant 1 : index
+// CHECK:      [[BROADCAST_RHS:%.*]] = alloc() {temp = true}
+// CHECK:      [[DIFF:%.*]] = alloc() {temp = true}
+// CHECK:      loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME:     to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
+// CHECK:        [[RHS_ELEM:%.*]] = load [[RHS]]{{\[}}[[I]]]
+// CHECK:        store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[LHS_ELEM:%.*]] = load [[LHS]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[BROADCAST_RHS_ELEM:%.*]] = load [[BROADCAST_RHS]]
+// CHECK:        [[DIFF_ELEM:%.*]] = subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
+// CHECK:        store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[DIFF_ELEM_:%.*]] = load [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK:        [[EXP_ELEM:%.*]] = exp [[DIFF_ELEM_]]
+// CHECK:        store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK:        "loop.terminator"() : () -> ()
+// CHECK:      }
+// CHECK:      dealloc [[BROADCAST_RHS]]
+// CHECK:      dealloc [[DIFF]]
+
+// -----
+
+func @do_not_fuse_nested_ploop1() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      "loop.terminator"() : () -> ()
+    }
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_nested_ploop1
+// CHECK:        loop.parallel
+// CHECK:          loop.parallel
+// CHECK:        loop.parallel
+
+// -----
+
+func @do_not_fuse_nested_ploop2() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      "loop.terminator"() : () -> ()
+    }
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_nested_ploop2
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel
+// CHECK:          loop.parallel
+
+// -----
+
+func @do_not_fuse_loops_unmatching_num_loops() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i) = (%c0) to (%c2) step (%c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_loops_unmatching_num_loops
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel
+
+// -----
+
+func @do_not_fuse_loops_unmatching_iteration_space() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c4 = constant 4 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c4, %c4) step (%c2, %c2) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_loops_unmatching_iteration_space
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel
+
+// -----
+
+func @do_not_fuse_unmatching_read_write_patterns(
+    %A: memref<2x2xf32>, %B: memref<2x2xf32>,
+    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %sum = alloc() {temp = true} : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = load %B[%i, %j] : memref<2x2xf32>
+    %C_elem = load %C[%i, %j] : memref<2x2xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %k = addi %i, %c1 : index
+    %sum_elem = load %sum[%k, %j] : memref<2x2xf32>
+    %A_elem = load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = mulf %sum_elem, %A_elem : f32
+    store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_unmatching_read_write_patterns
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel
+
+// -----
+
+func @do_not_fuse_unmatching_write_read_patterns(
+    %A: memref<2x2xf32>, %B: memref<2x2xf32>, %C: memref<2x2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %k = addi %i, %c1 : index
+    %A_elem = load %A[%k, %j] : memref<2x2xf32>
+    store %A_elem, %B[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %C_elem = load %C[%i, %j] : memref<2x2xf32>
+    store %C_elem, %A[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_unmatching_write_read_patterns
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel
+
+// -----
+
+func @do_not_fuse_loops_with_ops_in_between(
+    %A: memref<2x2xf32>, %B: memref<2x2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %buffer = alloc() {temp = true} : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %A_elem = load %A[%i, %j] : memref<2x2xf32>
+    store %A_elem, %buffer[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  %first_elem = load %buffer[%c0, %c0] : memref<2x2xf32>
+  store %first_elem, %B[%c0, %c0] : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %buffer_elem = load %buffer[%i, %j] : memref<2x2xf32>
+    store %buffer_elem, %A[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %buffer : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_loops_with_ops_in_between
+// CHECK:        loop.parallel
+// CHECK:        loop.parallel