diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -25,6 +25,10 @@ class AffineForOp; class GreedyRewriteConfig; +/// Fusion mode to attempt. The default mode `Greedy` does both +/// producer-consumer and sibling fusion. +enum FusionMode { Greedy, ProducerConsumer, Sibling }; + //===----------------------------------------------------------------------===// // Passes //===----------------------------------------------------------------------===// @@ -72,13 +76,14 @@ /// Creates a pass to perform common sub expression elimination. std::unique_ptr<Pass> createCSEPass(); -/// Creates a loop fusion pass which fuses loops. Buffers of size less than or -/// equal to `localBufSizeThreshold` are promoted to memory space -/// `fastMemorySpace'. +/// Creates a loop fusion pass which fuses loops according to the type of fusion +/// specified in `fusionMode`. Buffers of size less than or equal to +/// `localBufSizeThreshold` are promoted to memory space `fastMemorySpace`. std::unique_ptr<OperationPass<FuncOp>> createLoopFusionPass(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0, - bool maximalFusion = false); + bool maximalFusion = false, + enum FusionMode fusionMode = FusionMode::Greedy); /// Creates a loop invariant code motion pass that hoists loop invariant /// instructions out of the loop.
diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -136,7 +136,15 @@ "to fast memory space">, Option<"maximalFusion", "fusion-maximal", "bool", /*default=*/"false", "Enables maximal loop fusion">, - ]; + Option<"affineFusionMode", "affine-fusion-mode", "enum FusionMode", + "mlir::FusionMode::Greedy", "fusion mode to attempt", + "llvm::cl::values(clEnumValN(mlir::FusionMode::Greedy," + " \"greedy\", \"Perform greedy fusion\")," + "clEnumValN( mlir::FusionMode::ProducerConsumer, " + "\"producer\",\"Perform only producer->consumer fusion\")," + "clEnumValN( mlir::FusionMode::Sibling," + "\"sibling\",\"Perform only sibling fusion\"))">, + ]; let dependentDialects = ["memref::MemRefDialect"]; } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -49,10 +49,11 @@ struct LoopFusion : public AffineLoopFusionBase<LoopFusion> { LoopFusion() = default; LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes, - bool maximalFusion) { + bool maximalFusion, enum FusionMode affineFusionMode) { this->fastMemorySpace = fastMemorySpace; this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024; this->maximalFusion = maximalFusion; + this->affineFusionMode = affineFusionMode; } void runOnFunction() override; @@ -62,9 +63,10 @@ std::unique_ptr<OperationPass<FuncOp>> mlir::createLoopFusionPass(unsigned fastMemorySpace, - uint64_t localBufSizeThreshold, bool maximalFusion) { + uint64_t localBufSizeThreshold, bool maximalFusion, + enum FusionMode affineFusionMode) { return std::make_unique<LoopFusion>(fastMemorySpace, localBufSizeThreshold, - maximalFusion); + maximalFusion, affineFusionMode); } namespace { @@ -1391,13 +1393,25 @@ worklist.push_back(node.id); } } + /// Run only sibling fusion on the `mdg`.
+ void runSiblingFusionOnly() { + fuseSiblingNodes(); + eraseUnusedMemRefAllocations(); + } + + /// Run only producer/consumer fusion on the `mdg`. + void runProducerConsumerFusionOnly() { + fuseProducerConsumerNodes( + /*maxSrcUserCount=*/std::numeric_limits<unsigned>::max()); + eraseUnusedMemRefAllocations(); + } // Run the GreedyFusion pass. // *) First pass through the nodes fuses single-use producer nodes into their // unique consumer. // *) Second pass fuses sibling nodes which share no dependence edges. // *) Third pass fuses any remaining producer nodes into their users. - void run() { + void runGreedyFusion() { // TODO: Run this repeatedly until a fixed-point is reached. fuseProducerConsumerNodes(/*maxSrcUserCount=*/1); fuseSiblingNodes(); @@ -1971,5 +1985,11 @@ unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024; GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt, maximalFusion, computeToleranceThreshold); - fusion.run(); + + if (affineFusionMode == FusionMode::ProducerConsumer) + fusion.runProducerConsumerFusionOnly(); + else if (affineFusionMode == FusionMode::Sibling) + fusion.runSiblingFusionOnly(); + else + fusion.runGreedyFusion(); } diff --git a/mlir/lib/Transforms/PassDetail.h b/mlir/lib/Transforms/PassDetail.h --- a/mlir/lib/Transforms/PassDetail.h +++ b/mlir/lib/Transforms/PassDetail.h @@ -10,6 +10,7 @@ #define TRANSFORMS_PASSDETAIL_H_ #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Passes.h" namespace mlir { class AffineDialect; diff --git a/mlir/test/Transforms/loop-fusion-4.mlir b/mlir/test/Transforms/loop-fusion-4.mlir --- a/mlir/test/Transforms/loop-fusion-4.mlir +++ b/mlir/test/Transforms/loop-fusion-4.mlir @@ -1,7 +1,7 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s -// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal" -split-input-file | FileCheck %s --check-prefix=MAXIMAL +// RUN: mlir-opt -allow-unregistered-dialect
%s -affine-loop-fusion="affine-fusion-mode=producer" -split-input-file | FileCheck %s --check-prefix=PRODUCERCONSUMER +// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal affine-fusion-mode=sibling" -split-input-file | FileCheck %s --check-prefix=SIBLINGMAXIMAL -// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir. +// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir. // Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir // Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir @@ -38,17 +38,17 @@ // Test checks the loop structure is preserved after sibling fusion // since the destination loop and source loop trip counts do not // match. -// MAXIMAL-LABEL: func @reduce_add_non_maximal_f32_f32( -// MAXIMAL: %[[cst_0:.*]] = constant 0.000000e+00 : f32 -// MAXIMAL-NEXT: %[[cst_1:.*]] = constant 1.000000e+00 : f32 -// MAXIMAL-NEXT: affine.for %[[idx_0:.*]]= 0 to 1 { -// MAXIMAL-NEXT: affine.for %[[idx_1:.*]] = 0 to 64 { -// MAXIMAL-NEXT: %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) { -// MAXIMAL-NEXT: %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) { +// SIBLINGMAXIMAL-LABEL: func @reduce_add_non_maximal_f32_f32( +// SIBLINGMAXIMAL: %[[cst_0:.*]] = constant 0.000000e+00 : f32 +// SIBLINGMAXIMAL-NEXT: %[[cst_1:.*]] = constant 1.000000e+00 : f32 +// SIBLINGMAXIMAL-NEXT: affine.for %[[idx_0:.*]]= 0 to 1 { +// SIBLINGMAXIMAL-NEXT: affine.for %[[idx_1:.*]] = 0 to 64 { +// SIBLINGMAXIMAL-NEXT: %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) { +// SIBLINGMAXIMAL-NEXT: %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) { // Expects fusion of producer into consumer at depth 4 and subsequent removal of // source loop. 
-// CHECK-LABEL: func @unflatten4d +// PRODUCERCONSUMER-LABEL: func @unflatten4d func @unflatten4d(%arg1: memref<7x8x9x10xf32>) { %m = memref.alloc() : memref<5040xf32> %cf7 = constant 7.0 : f32 @@ -75,18 +75,18 @@ return } -// CHECK: affine.for -// CHECK-NEXT: affine.for -// CHECK-NEXT: affine.for -// CHECK-NEXT: affine.for -// CHECK-NOT: affine.for -// CHECK: return +// PRODUCERCONSUMER: affine.for +// PRODUCERCONSUMER-NEXT: affine.for +// PRODUCERCONSUMER-NEXT: affine.for +// PRODUCERCONSUMER-NEXT: affine.for +// PRODUCERCONSUMER-NOT: affine.for +// PRODUCERCONSUMER: return // ----- // Expects fusion of producer into consumer at depth 2 and subsequent removal of // source loop. -// CHECK-LABEL: func @unflatten2d_with_transpose +// PRODUCERCONSUMER-LABEL: func @unflatten2d_with_transpose func @unflatten2d_with_transpose(%arg1: memref<8x7xf32>) { %m = memref.alloc() : memref<56xf32> %cf7 = constant 7.0 : f32 @@ -105,7 +105,7 @@ return } -// CHECK: affine.for -// CHECK-NEXT: affine.for -// CHECK-NOT: affine.for -// CHECK: return \ No newline at end of file +// PRODUCERCONSUMER: affine.for +// PRODUCERCONSUMER-NEXT: affine.for +// PRODUCERCONSUMER-NOT: affine.for +// PRODUCERCONSUMER: return \ No newline at end of file