diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -190,6 +190,7 @@
     iterator_range<Block::args_iterator> getInductionVars() {
      return {getBody()->args_begin(), getBody()->args_end()};
    }
+    unsigned getNumLoops() { return step().size(); }
  }];
 }
diff --git a/mlir/include/mlir/Dialect/LoopOps/Passes.h b/mlir/include/mlir/Dialect/LoopOps/Passes.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LoopOps/Passes.h
@@ -0,0 +1,27 @@
+//===- Passes.h - Pass Entrypoints ------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header file defines prototypes that expose pass constructors.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LOOPOPS_PASSES_H_
+#define MLIR_DIALECT_LOOPOPS_PASSES_H_
+
+#include <memory>
+
+namespace mlir {
+
+class Pass;
+
+/// Creates a loop fusion pass which fuses parallel loops.
+std::unique_ptr<Pass> createParallelLoopFusionPass();
+
+} // namespace mlir
+
+#endif // MLIR_DIALECT_LOOPOPS_PASSES_H_
diff --git a/mlir/include/mlir/InitAllPasses.h b/mlir/include/mlir/InitAllPasses.h
--- a/mlir/include/mlir/InitAllPasses.h
+++ b/mlir/include/mlir/InitAllPasses.h
@@ -27,6 +27,7 @@
 #include "mlir/Dialect/FxpMathOps/Passes.h"
 #include "mlir/Dialect/GPU/Passes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/LoopOps/Passes.h"
 #include "mlir/Dialect/QuantOps/Passes.h"
 #include "mlir/Dialect/SPIRV/Passes.h"
 #include "mlir/Quantizer/Transforms/Passes.h"
@@ -104,6 +105,9 @@
   createConvertLinalgToAffineLoopsPass();
   createConvertLinalgToLLVMPass();
 
+  // LoopOps
+  createParallelLoopFusionPass();
+
   // QuantOps
   quant::createConvertSimulatedQuantPass();
   quant::createConvertConstPass();
diff --git a/mlir/lib/Dialect/LoopOps/CMakeLists.txt b/mlir/lib/Dialect/LoopOps/CMakeLists.txt
--- a/mlir/lib/Dialect/LoopOps/CMakeLists.txt
+++ b/mlir/lib/Dialect/LoopOps/CMakeLists.txt
@@ -21,3 +21,5 @@
   MLIRStandardOps
   LLVMSupport
   )
+
+add_subdirectory(Transforms)
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/LoopOps/Transforms/CMakeLists.txt
@@ -0,0 +1,11 @@
+add_llvm_library(MLIRLoopOpsTransforms
+  ParallelLoopFusion.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LoopOps
+  )
+
+target_link_libraries(MLIRLoopOpsTransforms
+  MLIRPass
+  MLIRLoopOps
+  )
diff --git a/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/LoopOps/Transforms/ParallelLoopFusion.cpp
@@ -0,0 +1,182 @@
+//===- ParallelLoopFusion.cpp - Code to perform loop fusion ---------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements loop fusion on parallel loops.
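+//
+// The pass walks every block of the operation it runs on, collects maximal
+// chains of adjacent loop.parallel ops that are separated only by
+// side-effect-free ops, and greedily fuses each loop in a chain into its
+// successor when both have equal iteration spaces and a naive dependence
+// check succeeds.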
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/LoopOps/LoopOps.h"
+#include "mlir/Dialect/LoopOps/Passes.h"
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/IR/BlockAndValueMapping.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+using loop::ParallelOp;
+
+/// Returns true if `ploop` contains a nested ParallelOp.
+static bool hasNestedParallelOp(ParallelOp ploop) {
+  auto walkResult = ploop.getBody()->walk(
+      [](ParallelOp) { return WalkResult::interrupt(); });
+  return walkResult.wasInterrupted();
+}
+
+/// Returns true if the iteration spaces of the two loops match, i.e. they
+/// have the same number of induction variables and the same lower bound,
+/// upper bound, and step operands.
+static bool equalIterationSpaces(ParallelOp firstPloop,
+                                 ParallelOp secondPloop) {
+  if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
+    return false;
+
+  auto matchOperands = [&](const OperandRange &lhs,
+                           const OperandRange &rhs) -> bool {
+    // TODO: Extend this to support aliases and equal constants.
+    return std::equal(lhs.begin(), lhs.end(), rhs.begin());
+  };
+  return matchOperands(firstPloop.lowerBound(), secondPloop.lowerBound()) &&
+         matchOperands(firstPloop.upperBound(), secondPloop.upperBound()) &&
+         matchOperands(firstPloop.step(), secondPloop.step());
+}
+
+/// Returns true if the defining operation for the memref is inside the body
+/// of the parallel loop.
+static bool isDefinedInPloopBody(Value memref, ParallelOp ploop) {
+  auto *memrefDef = memref.getDefiningOp();
+  return memrefDef && ploop.getOperation()->isAncestor(memrefDef);
+}
+
+/// Checks if the parallel loops have mixed access to the same buffers.
+/// Returns `true` if every load in `secondPloop` from a buffer written by
+/// `firstPloop` reads exactly the indices that `firstPloop` stored to, with
+/// the induction variables mapped via `firstToSecondPloopIndices`.
+static bool haveNoReadsAfterWriteExceptSameIndex(
+    ParallelOp firstPloop, ParallelOp secondPloop,
+    const BlockAndValueMapping &firstToSecondPloopIndices) {
+  DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
+  firstPloop.getBody()->walk([&](StoreOp store) {
+    bufferStores[store.getMemRef()].push_back(store.indices());
+  });
+  auto walkResult = secondPloop.getBody()->walk([&](LoadOp load) {
+    // Stop if the memref is defined in secondPloop body. Careful alias
+    // analysis is needed.
+    auto *memrefDef = load.getMemRef().getDefiningOp();
+    if (memrefDef && memrefDef->getBlock() == load.getOperation()->getBlock())
+      return WalkResult::interrupt();
+
+    auto write = bufferStores.find(load.getMemRef());
+    if (write == bufferStores.end())
+      return WalkResult::advance();
+
+    // Allow only a single write access per buffer.
+    if (write->second.size() != 1)
+      return WalkResult::interrupt();
+
+    // Check that the load indices of secondPloop coincide with the store
+    // indices of firstPloop for the same memrefs.
+    auto storeIndices = write->second.front();
+    auto loadIndices = load.indices();
+    if (storeIndices.size() != loadIndices.size())
+      return WalkResult::interrupt();
+    for (int i = 0, e = storeIndices.size(); i < e; ++i) {
+      if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
+          loadIndices[i])
+        return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
+
+/// Analyzes dependencies in the most primitive way by checking simple read
+/// and write patterns.
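+/// The check runs in both directions: loads in `secondPloop` from buffers
+/// written by `firstPloop` must use the matching store indices, and
+/// symmetrically for loads in `firstPloop` from buffers written by
+/// `secondPloop`, so that fusion preserves every cross-loop dependence.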
+static LogicalResult
+verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
+                   const BlockAndValueMapping &firstToSecondPloopIndices) {
+  if (!haveNoReadsAfterWriteExceptSameIndex(firstPloop, secondPloop,
+                                            firstToSecondPloopIndices))
+    return failure();
+
+  BlockAndValueMapping secondToFirstPloopIndices;
+  secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
+                                firstPloop.getBody()->getArguments());
+  return success(haveNoReadsAfterWriteExceptSameIndex(
+      secondPloop, firstPloop, secondToFirstPloopIndices));
+}
+
+static bool
+isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
+              const BlockAndValueMapping &firstToSecondPloopIndices) {
+  return !hasNestedParallelOp(firstPloop) &&
+         !hasNestedParallelOp(secondPloop) &&
+         equalIterationSpaces(firstPloop, secondPloop) &&
+         succeeded(verifyDependencies(firstPloop, secondPloop,
+                                      firstToSecondPloopIndices));
+}
+
+/// Prepends the operations of firstPloop's body to secondPloop's body and
+/// erases firstPloop, if the fusion is legal.
+static void fuseIfLegal(ParallelOp firstPloop, ParallelOp secondPloop,
+                        OpBuilder b) {
+  BlockAndValueMapping firstToSecondPloopIndices;
+  firstToSecondPloopIndices.map(firstPloop.getBody()->getArguments(),
+                                secondPloop.getBody()->getArguments());
+
+  if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices))
+    return;
+
+  b.setInsertionPointToStart(secondPloop.getBody());
+  for (auto &op : firstPloop.getBody()->without_terminator())
+    b.clone(op, firstToSecondPloopIndices);
+  firstPloop.erase();
+}
+
+static void naivelyFuseParallelOps(Operation *op) {
+  OpBuilder b(op);
+  // Consider every single block and attempt to fuse adjacent loops.
+  for (auto &region : op->getRegions()) {
+    for (auto &block : region.getBlocks()) {
+      SmallVector<SmallVector<ParallelOp, 8>, 1> ploopChains{{}};
+      // Not using `walk()` to traverse only top-level parallel loops and to
+      // make sure that there are no side-effecting ops between the parallel
+      // loops.
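+      // A chain breaks at the first side-effecting op seen after a parallel
+      // loop; the next parallel loop then starts a new chain, so fusion is
+      // only attempted between loops with nothing side-effecting in between.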
+      bool noSideEffects = true;
+      for (auto &op : block.getOperations()) {
+        if (auto ploop = dyn_cast<ParallelOp>(op)) {
+          if (noSideEffects) {
+            ploopChains.back().push_back(ploop);
+          } else {
+            ploopChains.push_back({ploop});
+            noSideEffects = true;
+          }
+          continue;
+        }
+        noSideEffects &= op.hasNoSideEffect();
+      }
+      for (ArrayRef<ParallelOp> ploops : ploopChains) {
+        for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
+          fuseIfLegal(ploops[i], ploops[i + 1], b);
+      }
+    }
+  }
+}
+
+namespace {
+
+struct ParallelLoopFusion : public OperationPass<ParallelLoopFusion> {
+  void runOnOperation() override { naivelyFuseParallelOps(getOperation()); }
+};
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createParallelLoopFusionPass() {
+  return std::make_unique<ParallelLoopFusion>();
+}
+
+static PassRegistration<ParallelLoopFusion>
+    pass("parallel-loop-fusion", "Fuse adjacent parallel loops.");
diff --git a/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir b/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Loops/parallel-loop-fusion.mlir
@@ -0,0 +1,309 @@
+// RUN: mlir-opt %s -pass-pipeline='func(parallel-loop-fusion)' -split-input-file | FileCheck %s --dump-input-on-failure
+
+func @fuse_empty_loops() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @fuse_empty_loops
+// CHECK: [[C2:%.*]] = constant 2 : index
+// CHECK: [[C0:%.*]] = constant 0 : index
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK-NOT: loop.parallel
+
+// -----
+
+func @fuse_two(%A: memref<2x2xf32>, %B: memref<2x2xf32>,
+               %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %sum = alloc() : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = load %B[%i, %j] : memref<2x2xf32>
+    %C_elem = load %C[%i, %j] : memref<2x2xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %sum_elem = load %sum[%i, %j] : memref<2x2xf32>
+    %A_elem = load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = mulf %sum_elem, %A_elem : f32
+    store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @fuse_two
+// CHECK-SAME: ([[A:%.*]]: {{.*}}, [[B:%.*]]: {{.*}}, [[C:%.*]]: {{.*}},
+// CHECK-SAME: [[RESULT:%.*]]: {{.*}}) {
+// CHECK: [[C2:%.*]] = constant 2 : index
+// CHECK: [[C0:%.*]] = constant 0 : index
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[SUM:%.*]] = alloc()
+// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C2]], [[C2]]) step ([[C1]], [[C1]]) {
+// CHECK: [[B_ELEM:%.*]] = load [[B]]{{\[}}[[I]], [[J]]]
+// CHECK: [[C_ELEM:%.*]] = load [[C]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM:%.*]] = addf [[B_ELEM]], [[C_ELEM]]
+// CHECK: store [[SUM_ELEM]], [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK: [[SUM_ELEM_:%.*]] = load [[SUM]]{{\[}}[[I]], [[J]]]
+// CHECK: [[A_ELEM:%.*]] = load [[A]]{{\[}}[[I]], [[J]]]
+// CHECK: [[PRODUCT_ELEM:%.*]] = mulf [[SUM_ELEM_]], [[A_ELEM]]
+// CHECK: store [[PRODUCT_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK: dealloc [[SUM]]
+
+// -----
+
+func @fuse_three(%lhs: memref<100x10xf32>, %rhs: memref<100xf32>,
+                 %result: memref<100x10xf32>) {
+  %c100 = constant 100 : index
+  %c10 = constant 10 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %broadcast_rhs = alloc() : memref<100x10xf32>
+  %diff = alloc() : memref<100x10xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
+    %rhs_elem = load %rhs[%i] : memref<100xf32>
+    store %rhs_elem, %broadcast_rhs[%i, %j] : memref<100x10xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
+    %lhs_elem = load %lhs[%i, %j] : memref<100x10xf32>
+    %broadcast_rhs_elem = load %broadcast_rhs[%i, %j] : memref<100x10xf32>
+    %diff_elem = subf %lhs_elem, %broadcast_rhs_elem : f32
+    store %diff_elem, %diff[%i, %j] : memref<100x10xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c100, %c10) step (%c1, %c1) {
+    %diff_elem = load %diff[%i, %j] : memref<100x10xf32>
+    %exp_elem = exp %diff_elem : f32
+    store %exp_elem, %result[%i, %j] : memref<100x10xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %broadcast_rhs : memref<100x10xf32>
+  dealloc %diff : memref<100x10xf32>
+  return
+}
+// CHECK-LABEL: func @fuse_three
+// CHECK-SAME: ([[LHS:%.*]]: memref<100x10xf32>, [[RHS:%.*]]: memref<100xf32>,
+// CHECK-SAME: [[RESULT:%.*]]: memref<100x10xf32>) {
+// CHECK: [[C100:%.*]] = constant 100 : index
+// CHECK: [[C10:%.*]] = constant 10 : index
+// CHECK: [[C0:%.*]] = constant 0 : index
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[BROADCAST_RHS:%.*]] = alloc()
+// CHECK: [[DIFF:%.*]] = alloc()
+// CHECK: loop.parallel ([[I:%.*]], [[J:%.*]]) = ([[C0]], [[C0]])
+// CHECK-SAME: to ([[C100]], [[C10]]) step ([[C1]], [[C1]]) {
+// CHECK: [[RHS_ELEM:%.*]] = load [[RHS]]{{\[}}[[I]]]
+// CHECK: store [[RHS_ELEM]], [[BROADCAST_RHS]]{{\[}}[[I]], [[J]]]
+// CHECK: [[LHS_ELEM:%.*]] = load [[LHS]]{{\[}}[[I]], [[J]]]
+// CHECK: [[BROADCAST_RHS_ELEM:%.*]] = load [[BROADCAST_RHS]]
+// CHECK: [[DIFF_ELEM:%.*]] = subf [[LHS_ELEM]], [[BROADCAST_RHS_ELEM]]
+// CHECK: store [[DIFF_ELEM]], [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK: [[DIFF_ELEM_:%.*]] = load [[DIFF]]{{\[}}[[I]], [[J]]]
+// CHECK: [[EXP_ELEM:%.*]] = exp [[DIFF_ELEM_]]
+// CHECK: store [[EXP_ELEM]], [[RESULT]]{{\[}}[[I]], [[J]]]
+// CHECK: "loop.terminator"() : () -> ()
+// CHECK: }
+// CHECK: dealloc [[BROADCAST_RHS]]
+// CHECK: dealloc [[DIFF]]
+
+// -----
+
+func @do_not_fuse_nested_ploop1() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      "loop.terminator"() : () -> ()
+    }
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_nested_ploop1
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+func @do_not_fuse_nested_ploop2() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    loop.parallel (%k, %l) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+      "loop.terminator"() : () -> ()
+    }
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_nested_ploop2
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+func @do_not_fuse_loops_unmatching_num_loops() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i) = (%c0) to (%c2) step (%c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_loops_unmatching_num_loops
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+func @do_not_fuse_loops_with_side_effecting_ops_in_between() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  %buffer = alloc() : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_loops_with_side_effecting_ops_in_between
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+func @do_not_fuse_loops_unmatching_iteration_space() {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %c4 = constant 4 : index
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c4, %c4) step (%c2, %c2) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_loops_unmatching_iteration_space
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+func @do_not_fuse_unmatching_write_read_patterns(
+    %A: memref<2x2xf32>, %B: memref<2x2xf32>,
+    %C: memref<2x2xf32>, %result: memref<2x2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %common_buf = alloc() : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = load %B[%i, %j] : memref<2x2xf32>
+    %C_elem = load %C[%i, %j] : memref<2x2xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %common_buf[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %k = addi %i, %c1 : index
+    %sum_elem = load %common_buf[%k, %j] : memref<2x2xf32>
+    %A_elem = load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = mulf %sum_elem, %A_elem : f32
+    store %product_elem, %result[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %common_buf : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_unmatching_write_read_patterns
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+func @do_not_fuse_unmatching_read_write_patterns(
+    %A: memref<2x2xf32>, %B: memref<2x2xf32>, %common_buf: memref<2x2xf32>) {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %sum = alloc() : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %B_elem = load %B[%i, %j] : memref<2x2xf32>
+    %C_elem = load %common_buf[%i, %j] : memref<2x2xf32>
+    %sum_elem = addf %B_elem, %C_elem : f32
+    store %sum_elem, %sum[%i, %j] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %k = addi %i, %c1 : index
+    %sum_elem = load %sum[%k, %j] : memref<2x2xf32>
+    %A_elem = load %A[%i, %j] : memref<2x2xf32>
+    %product_elem = mulf %sum_elem, %A_elem : f32
+    store %product_elem, %common_buf[%j, %i] : memref<2x2xf32>
+    "loop.terminator"() : () -> ()
+  }
+  dealloc %sum : memref<2x2xf32>
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_unmatching_read_write_patterns
+// CHECK: loop.parallel
+// CHECK: loop.parallel
+
+// -----
+
+func @do_not_fuse_loops_with_memref_defined_in_loop_bodies() {
+  %c2 = constant 2 : index
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %buffer = alloc() : memref<2x2xf32>
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    "loop.terminator"() : () -> ()
+  }
+  loop.parallel (%i, %j) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) {
+    %A = subview %buffer[%c0, %c0][%c2, %c2][%c1, %c1]
+      : memref<2x2xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+    %A_elem = load %A[%i, %j] : memref<?x?xf32, offset: ?, strides: [?, ?]>
+    "loop.terminator"() : () -> ()
+  }
+  return
+}
+// CHECK-LABEL: func @do_not_fuse_loops_with_memref_defined_in_loop_bodies
+// CHECK: loop.parallel
+// CHECK: loop.parallel
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -19,6 +19,7 @@
 )
 set(LIBS
+  MLIRLoopOpsTransforms
   MLIRLoopAnalysis
   MLIRAnalysis
   MLIRAffineOps