diff --git a/mlir/include/mlir/Dialect/Affine/Passes.h b/mlir/include/mlir/Dialect/Affine/Passes.h --- a/mlir/include/mlir/Dialect/Affine/Passes.h +++ b/mlir/include/mlir/Dialect/Affine/Passes.h @@ -73,10 +73,11 @@ /// bounds into a single loop. std::unique_ptr> createLoopCoalescingPass(); -/// Creates a loop fusion pass which fuses loops according to type of fusion +/// Creates a loop fusion pass which fuses affine loop nests at the top-level of +/// the operation the pass is created on according to the type of fusion /// specified in `fusionMode`. Buffers of size less than or equal to /// `localBufSizeThreshold` are promoted to memory space `fastMemorySpace`. -std::unique_ptr> +std::unique_ptr createLoopFusionPass(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0, bool maximalFusion = false, diff --git a/mlir/include/mlir/Dialect/Affine/Passes.td b/mlir/include/mlir/Dialect/Affine/Passes.td --- a/mlir/include/mlir/Dialect/Affine/Passes.td +++ b/mlir/include/mlir/Dialect/Affine/Passes.td @@ -43,22 +43,24 @@ ]; } -def AffineLoopFusion : Pass<"affine-loop-fusion", "func::FuncOp"> { +def AffineLoopFusion : Pass<"affine-loop-fusion"> { let summary = "Fuse affine loop nests"; let description = [{ - This pass performs fusion of loop nests using a slicing-based approach. It - combines two fusion strategies: producer-consumer fusion and sibling fusion. - Producer-consumer fusion is aimed at fusing pairs of loops where the first - one writes to a memref that the second reads. Sibling fusion targets pairs - of loops that share no dependences between them but that load from the same - memref. The fused loop nests, when possible, are rewritten to access - significantly smaller local buffers instead of the original memref's, and - the latter are often either completely optimized away or contracted. This - transformation leads to enhanced locality and lower memory footprint through - the elimination or contraction of temporaries/intermediate memref's. 
These - benefits are sometimes achieved at the expense of redundant computation - through a cost model that evaluates available choices such as the depth at - which a source slice should be materialized in the designation slice. + This pass performs fusion of loop nests using a slicing-based approach. The + transformation works on an MLIR `Block` granularity and applies to all + blocks of the operation the pass is run on. It combines two fusion strategies: + producer-consumer fusion and sibling fusion. Producer-consumer fusion is + aimed at fusing pairs of loops where the first one writes to a memref that + the second reads. Sibling fusion targets pairs of loops that share no + dependences between them but that load from the same memref. The fused loop + nests, when possible, are rewritten to access significantly smaller local + buffers instead of the original memref's, and the latter are often either + completely optimized away or contracted. This transformation leads to + enhanced locality and lower memory footprint through the elimination or + contraction of temporaries/intermediate memref's. These benefits are + sometimes achieved at the expense of redundant computation through a cost + model that evaluates available choices such as the depth at which a source + slice should be materialized in the destination slice. Example 1: Producer-consumer fusion. Input: diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp --- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -// This file implements loop fusion. +// This file implements affine fusion.
// //===----------------------------------------------------------------------===// @@ -19,7 +19,6 @@ #include "mlir/Dialect/Affine/LoopFusionUtils.h" #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Affine/Utils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -65,12 +64,13 @@ this->affineFusionMode = affineFusionMode; } + void runOnBlock(Block *block); void runOnOperation() override; }; } // namespace -std::unique_ptr> +std::unique_ptr mlir::createLoopFusionPass(unsigned fastMemorySpace, uint64_t localBufSizeThreshold, bool maximalFusion, enum FusionMode affineFusionMode) { @@ -104,7 +104,7 @@ }; // MemRefDependenceGraph is a graph data structure where graph nodes are -// top-level operations in a FuncOp which contain load/store ops, and edges +// top-level operations in a `Block` which contain load/store ops, and edges // are memref dependences between the nodes. // TODO: Add a more flexible dependence graph representation. // TODO: Add a depth parameter to dependence graph construction. @@ -207,11 +207,11 @@ // The next unique identifier to use for newly created graph nodes. unsigned nextNodeId = 0; - MemRefDependenceGraph() = default; + MemRefDependenceGraph(Block &block) : block(block) {} // Initializes the dependence graph based on operations in 'f'. // Returns true on success, false otherwise. - bool init(func::FuncOp f); + bool init(Block *block); // Returns the graph node for 'id'. Node *getNode(unsigned id) { @@ -258,7 +258,7 @@ } // Returns true if node 'id' writes to any memref which escapes (or is an - // argument to) the function/block. Returns false otherwise. + // argument to) the block. Returns false otherwise. bool writesToLiveInOrEscapingMemrefs(unsigned id) { Node *node = getNode(id); for (auto *storeOpInst : node->stores) { @@ -267,7 +267,8 @@ // Return true if 'memref' is a block argument. 
if (!op) return true; - // Return true if any use of 'memref' escapes the function. + // Return true if any use of 'memref' does not dereference it in an affine + // way. for (auto *user : memref.getUsers()) if (!isa(*user)) return true; @@ -597,6 +598,9 @@ } } void dump() const { print(llvm::errs()); } + + /// The block for which this graph is created to perform fusion. + Block &block; }; /// Returns true if node 'srcId' can be removed after fusing it with node @@ -710,13 +714,14 @@ producerConsumerMemrefs); } -/// A memref escapes the function if either: +/// A memref escapes in the context of the fusion pass if either: /// 1. it (or its alias) is a block argument, or /// 2. created by an op not known to guarantee alias freedom, -/// 3. it (or its alias) is used by a non-affine op (e.g., call op, memref -/// load/store ops, alias creating ops, unknown ops, etc.); such ops -/// do not deference the memref in an affine way. -static bool isEscapingMemref(Value memref) { +/// 3. it (or its alias) is used by ops other than affine dereferencing ops +/// (e.g., by call op, memref load/store ops, alias creating ops, unknown ops, +/// terminator ops, etc.); such ops do not dereference the memref in an affine +/// way. +static bool isEscapingMemref(Value memref, Block *block) { Operation *defOp = memref.getDefiningOp(); // Check if 'memref' is a block argument. if (!defOp) @@ -724,7 +729,7 @@ // Check if this is defined to be an alias of another memref. if (auto viewOp = dyn_cast(defOp)) - if (isEscapingMemref(viewOp.getViewSource())) + if (isEscapingMemref(viewOp.getViewSource(), block)) return true; // Any op besides allocating ops wouldn't guarantee alias freedom @@ -733,14 +738,18 @@ // Check if 'memref' is used by a non-deferencing op (including unknown ones) // (e.g., call ops, alias creating ops, etc.). - for (Operation *user : memref.getUsers()) + for (Operation *user : memref.getUsers()) { + // Ignore users outside of `block`.
+ if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != block) + continue; if (!isa(*user)) return true; + } return false; } /// Returns in 'escapingMemRefs' the memrefs from affine store ops in node 'id' -/// that escape the function or are accessed by non-affine ops. +/// that escape the block or are accessed in a non-affine way. void gatherEscapingMemrefs(unsigned id, MemRefDependenceGraph *mdg, DenseSet &escapingMemRefs) { auto *node = mdg->getNode(id); @@ -748,29 +757,25 @@ auto memref = cast(storeOp).getMemRef(); if (escapingMemRefs.count(memref)) continue; - if (isEscapingMemref(memref)) + if (isEscapingMemref(memref, &mdg->block)) escapingMemRefs.insert(memref); } } } // namespace -// Initializes the data dependence graph by walking operations in 'f'. +// Initializes the data dependence graph by walking operations in `block`. // Assigns each node in the graph a node id based on program order in 'f'. // TODO: Add support for taking a Block arg to construct the // dependence graph at a different depth. -bool MemRefDependenceGraph::init(func::FuncOp f) { +bool MemRefDependenceGraph::init(Block *block) { LLVM_DEBUG(llvm::dbgs() << "--- Initializing MDG ---\n"); // Map from a memref to the set of ids of the nodes that have ops accessing // the memref. DenseMap> memrefAccesses; - // TODO: support multi-block functions. - if (!llvm::hasSingleElement(f)) - return false; - DenseMap forToNodeMap; - for (auto &op : f.front()) { + for (Operation &op : *block) { if (auto forOp = dyn_cast(op)) { // Create graph node 'id' to represent top-level 'forOp' and record // all loads and store accesses it contains. @@ -845,14 +850,18 @@ // Stores don't define SSA values, skip them. 
if (!node.stores.empty()) continue; - auto *opInst = node.op; - for (auto value : opInst->getResults()) { - for (auto *user : value.getUsers()) { + Operation *opInst = node.op; + for (Value value : opInst->getResults()) { + for (Operation *user : value.getUsers()) { + // Ignore users outside of the block. + if (block->getParent()->findAncestorOpInRegion(*user)->getBlock() != + block) + continue; SmallVector loops; getLoopIVs(*user, &loops); if (loops.empty()) continue; - assert(forToNodeMap.count(loops[0]) > 0); + assert(forToNodeMap.count(loops[0]) > 0 && "missing mapping"); unsigned userLoopNestId = forToNodeMap[loops[0]]; addEdge(node.id, userLoopNestId, value); } @@ -918,7 +927,7 @@ // Create builder to insert alloc op just before 'forOp'. OpBuilder b(forInst); // Builder to create constants at the top level. - OpBuilder top(forInst->getParentOfType().getBody()); + OpBuilder top(forInst->getParentRegion()); // Create new memref type based on slice bounds. auto oldMemRef = cast(srcStoreOpInst).getMemRef(); auto oldMemRefType = oldMemRef.getType().cast(); @@ -979,7 +988,7 @@ // a constant shape. // TODO: Create/move alloc ops for private memrefs closer to their // consumer loop nests to reduce their live range. Currently they are added - // at the beginning of the function, because loop nests can be reordered + // at the beginning of the block, because loop nests can be reordered // during the fusion pass. Value newMemRef = top.create(forOp.getLoc(), newMemRefType); @@ -1508,8 +1517,8 @@ })) continue; - // Gather memrefs in 'srcNode' that are written and escape to the - // function (e.g., memref function arguments, returned memrefs, + // Gather memrefs in 'srcNode' that are written and escape out of the + // block (e.g., memref block arguments, returned memrefs, // memrefs passed to function calls, etc.). 
DenseSet srcEscapingMemRefs; gatherEscapingMemrefs(srcNode->id, mdg, srcEscapingMemRefs); @@ -1829,7 +1838,7 @@ } } - // Searches function argument uses and the graph from 'dstNode' looking for a + // Searches block argument uses and the graph from 'dstNode' looking for a // fusion candidate sibling node which shares no dependences with 'dstNode' // but which loads from the same memref. Returns true and sets // 'idAndMemrefToFuse' on success. Returns false otherwise. @@ -1874,36 +1883,37 @@ return true; }; - // Search for siblings which load the same memref function argument. - auto fn = dstNode->op->getParentOfType(); - for (unsigned i = 0, e = fn.getNumArguments(); i != e; ++i) { - for (auto *user : fn.getArgument(i).getUsers()) { - if (auto loadOp = dyn_cast(user)) { - // Gather loops surrounding 'use'. - SmallVector loops; - getLoopIVs(*user, &loops); - // Skip 'use' if it is not within a loop nest. - if (loops.empty()) - continue; - Node *sibNode = mdg->getForOpNode(loops[0]); - assert(sibNode != nullptr); - // Skip 'use' if it not a sibling to 'dstNode'. - if (sibNode->id == dstNode->id) - continue; - // Skip 'use' if it has been visited. - if (visitedSibNodeIds->count(sibNode->id) > 0) - continue; - // Skip 'use' if it does not load from the same memref as 'dstNode'. - auto memref = loadOp.getMemRef(); - if (dstNode->getLoadOpCount(memref) == 0) - continue; - // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. - if (canFuseWithSibNode(sibNode, memref)) { - visitedSibNodeIds->insert(sibNode->id); - idAndMemrefToFuse->first = sibNode->id; - idAndMemrefToFuse->second = memref; - return true; - } + // Search for siblings which load the same memref block argument. + Block *block = dstNode->op->getBlock(); + for (unsigned i = 0, e = block->getNumArguments(); i != e; ++i) { + for (Operation *user : block->getArgument(i).getUsers()) { + auto loadOp = dyn_cast(user); + if (!loadOp) + continue; + // Gather loops surrounding 'use'. 
+ SmallVector loops; + getLoopIVs(*user, &loops); + // Skip 'use' if it is not within a loop nest. + if (loops.empty()) + continue; + Node *sibNode = mdg->getForOpNode(loops[0]); + assert(sibNode != nullptr); + // Skip 'use' if it not a sibling to 'dstNode'. + if (sibNode->id == dstNode->id) + continue; + // Skip 'use' if it has been visited. + if (visitedSibNodeIds->count(sibNode->id) > 0) + continue; + // Skip 'use' if it does not load from the same memref as 'dstNode'. + auto memref = loadOp.getMemRef(); + if (dstNode->getLoadOpCount(memref) == 0) + continue; + // Check if 'sibNode/dstNode' can be input-reuse fused on 'memref'. + if (canFuseWithSibNode(sibNode, memref)) { + visitedSibNodeIds->insert(sibNode->id); + idAndMemrefToFuse->first = sibNode->id; + idAndMemrefToFuse->second = memref; + return true; } } } @@ -1968,8 +1978,7 @@ mdg->addToNode(dstNode->id, dstLoopCollector.loadOpInsts, dstLoopCollector.storeOpInsts); // Remove old sibling loop nest if it no longer has outgoing dependence - // edges, and it does not write to a memref which escapes the - // function. + // edges, and it does not write to a memref which escapes the block. if (mdg->getOutEdgeCount(sibNode->id) == 0) { Operation *op = sibNode->op; mdg->removeNode(sibNode->id); @@ -1996,9 +2005,10 @@ } // namespace -void LoopFusion::runOnOperation() { - MemRefDependenceGraph g; - if (!g.init(getOperation())) +/// Run fusion on `block`. 
+void LoopFusion::runOnBlock(Block *block) { + MemRefDependenceGraph g(*block); + if (!g.init(block)) return; Optional fastMemorySpaceOpt; @@ -2015,3 +2025,9 @@ else fusion.runGreedyFusion(); } + +void LoopFusion::runOnOperation() { + for (Region &region : getOperation()->getRegions()) + for (Block &block : region.getBlocks()) + runOnBlock(&block); +} diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp --- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp @@ -17,7 +17,6 @@ #include "mlir/Dialect/Affine/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Operation.h" #include "llvm/Support/Debug.h" @@ -471,7 +470,7 @@ auto walkResult = forOpRoot.walk([&](AffineForOp forOp) { auto *childForOp = forOp.getOperation(); auto *parentForOp = forOp->getParentOp(); - if (!llvm::isa(parentForOp)) { + if (forOp != forOpRoot) { if (!isa(parentForOp)) { LLVM_DEBUG(llvm::dbgs() << "Expected parent AffineForOp\n"); return WalkResult::interrupt(); diff --git a/mlir/test/Transforms/loop-fusion-2.mlir b/mlir/test/Transforms/loop-fusion-2.mlir --- a/mlir/test/Transforms/loop-fusion-2.mlir +++ b/mlir/test/Transforms/loop-fusion-2.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s -// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal" -split-input-file | FileCheck %s --check-prefix=MAXIMAL +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal}))' -split-input-file | FileCheck %s --check-prefix=MAXIMAL //
Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir. // Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir diff --git a/mlir/test/Transforms/loop-fusion-3.mlir b/mlir/test/Transforms/loop-fusion-3.mlir --- a/mlir/test/Transforms/loop-fusion-3.mlir +++ b/mlir/test/Transforms/loop-fusion-3.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s -// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal" -split-input-file | FileCheck %s --check-prefix=MAXIMAL +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal}))' -split-input-file | FileCheck %s --check-prefix=MAXIMAL // Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir. // Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir @@ -532,8 +532,8 @@ %2 = arith.divf %0, %1 : f32 } - // Loops '%i0' and '%i1' should be fused even though there is a defining - // node between the loops. It is because the node has no dependence from '%i0'. + // Loops '%i0' and '%i1' should be fused even though there is a defining node + // between the loops. It is because the node has no dependence from '%i0'. // CHECK: affine.load %{{.*}}[] : memref // CHECK-NEXT: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.load %{{.*}}[] : memref @@ -561,8 +561,8 @@ %2 = arith.divf %0, %1 : f32 } - // Loops '%i0' and '%i1' should not be fused because the defining node - // of '%0' used in '%i1' has dependence from loop '%i0'. + // Loops '%i0' and '%i1' should not be fused because the defining node of '%0' + // used in '%i1' has dependence from loop '%i0'. 
// CHECK: affine.for %{{.*}} = 0 to 10 { // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[] : memref // CHECK-NEXT: affine.store %{{.*}}, %{{.*}}[%{{.*}}] : memref<10xf32> diff --git a/mlir/test/Transforms/loop-fusion-4.mlir b/mlir/test/Transforms/loop-fusion-4.mlir --- a/mlir/test/Transforms/loop-fusion-4.mlir +++ b/mlir/test/Transforms/loop-fusion-4.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="mode=producer" -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER -// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal mode=sibling" -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL // Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir. 
// Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir @@ -141,3 +141,36 @@ // SIBLING-MAXIMAL-NEXT: affine.for %[[idx_1:.*]] = 0 to 64 { // SIBLING-MAXIMAL-NEXT: %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) { // SIBLING-MAXIMAL-NEXT: %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) { + +// ----- + +// PRODUCER-CONSUMER-LABEL: func @fusion_for_multiple_blocks() { +func.func @fusion_for_multiple_blocks() { +^bb0: + %m = memref.alloc() : memref<10xf32> + %cf7 = arith.constant 7.0 : f32 + + affine.for %i0 = 0 to 10 { + affine.store %cf7, %m[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + %v0 = affine.load %m[%i1] : memref<10xf32> + } + // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 10 { + // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> + // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32> + // PRODUCER-CONSUMER-NEXT: } + cf.br ^bb1 +^bb1: + affine.for %i0 = 0 to 10 { + affine.store %cf7, %m[%i0] : memref<10xf32> + } + affine.for %i1 = 0 to 10 { + %v0 = affine.load %m[%i1] : memref<10xf32> + } + // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 10 { + // PRODUCER-CONSUMER-NEXT: affine.store %{{.*}}, %{{.*}}[0] : memref<1xf32> + // PRODUCER-CONSUMER-NEXT: affine.load %{{.*}}[0] : memref<1xf32> + // PRODUCER-CONSUMER-NEXT: } + return +} diff --git a/mlir/test/Transforms/loop-fusion.mlir b/mlir/test/Transforms/loop-fusion.mlir --- a/mlir/test/Transforms/loop-fusion.mlir +++ b/mlir/test/Transforms/loop-fusion.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion))' -split-input-file | FileCheck %s // Part II of fusion tests in mlir/test/Transforms/loop-fusion=2.mlir. 
// Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir