diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -25,6 +25,10 @@ class AffineForOp; class GreedyRewriteConfig; +/// Fusion mode to attempt. The default mode `Greedy` does both +/// producer-consumer and sibling fusion. +enum FusionMode { Greedy, ProducerConsumer, Sibling }; + //===----------------------------------------------------------------------===// // Passes //===----------------------------------------------------------------------===// @@ -72,13 +76,14 @@ /// Creates a pass to perform common sub expression elimination. std::unique_ptr createCSEPass(); -/// Creates a loop fusion pass which fuses loops. Buffers of size less than or -/// equal to `localBufSizeThreshold` are promoted to memory space -/// `fastMemorySpace'. +/// Creates a loop fusion pass which fuses loops according to type of fusion +/// specified in `fusionMode`. Buffers of size less than or equal to +/// `localBufSizeThreshold` are promoted to memory space `fastMemorySpace`. std::unique_ptr> createLoopFusionPass(unsigned fastMemorySpace = 0, uint64_t localBufSizeThreshold = 0, - bool maximalFusion = false); + bool maximalFusion = false, + enum FusionMode fusionMode = FusionMode::Greedy); /// Creates a loop invariant code motion pass that hoists loop invariant /// instructions out of the loop. diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -136,7 +136,15 @@ "to fast memory space">, Option<"maximalFusion", "fusion-maximal", "bool", /*default=*/"false", "Enables maximal loop fusion">, - ]; + Option<"affineFusionMode", "mode", "enum FusionMode", + "mlir::FusionMode::Greedy", "fusion mode to attempt", + "llvm::cl::values(clEnumValN(mlir::FusionMode::Greedy," + " \"greedy\", \"Perform greedy (both producer-consumer and sibling) fusion\"), " + "clEnumValN( mlir::FusionMode::ProducerConsumer, " + "\"producer\", \"Perform only producer-consumer fusion\"), " + "clEnumValN( mlir::FusionMode::Sibling, " + "\"sibling\", \"Perform only sibling fusion\"))">, + ]; let dependentDialects = ["memref::MemRefDialect"]; } diff --git a/mlir/lib/Transforms/LoopFusion.cpp b/mlir/lib/Transforms/LoopFusion.cpp --- a/mlir/lib/Transforms/LoopFusion.cpp +++ b/mlir/lib/Transforms/LoopFusion.cpp @@ -49,10 +49,11 @@ struct LoopFusion : public AffineLoopFusionBase { LoopFusion() = default; LoopFusion(unsigned fastMemorySpace, uint64_t localBufSizeThresholdBytes, - bool maximalFusion) { + bool maximalFusion, enum FusionMode affineFusionMode) { this->fastMemorySpace = fastMemorySpace; this->localBufSizeThreshold = localBufSizeThresholdBytes / 1024; this->maximalFusion = maximalFusion; + this->affineFusionMode = affineFusionMode; } void runOnFunction() override; @@ -62,9 +63,10 @@ std::unique_ptr> mlir::createLoopFusionPass(unsigned fastMemorySpace, - uint64_t localBufSizeThreshold, bool maximalFusion) { + uint64_t localBufSizeThreshold, bool maximalFusion, + enum FusionMode affineFusionMode) { return std::make_unique(fastMemorySpace, localBufSizeThreshold, - maximalFusion); + maximalFusion, affineFusionMode); } namespace { @@ -1391,13 +1393,25 @@ worklist.push_back(node.id); } } + /// Run only sibling fusion on the `mdg`. + void runSiblingFusionOnly() { + fuseSiblingNodes(); + eraseUnusedMemRefAllocations(); + } + + /// Run only producer/consumer fusion on the `mdg`. + void runProducerConsumerFusionOnly() { + fuseProducerConsumerNodes( + /*maxSrcUserCount=*/std::numeric_limits::max()); + eraseUnusedMemRefAllocations(); + } // Run the GreedyFusion pass. // *) First pass through the nodes fuses single-use producer nodes into their // unique consumer. // *) Second pass fuses sibling nodes which share no dependence edges. // *) Third pass fuses any remaining producer nodes into their users. - void run() { + void runGreedyFusion() { // TODO: Run this repeatedly until a fixed-point is reached. fuseProducerConsumerNodes(/*maxSrcUserCount=*/1); fuseSiblingNodes(); @@ -1971,5 +1985,11 @@ unsigned localBufSizeThresholdBytes = localBufSizeThreshold * 1024; GreedyFusion fusion(&g, localBufSizeThresholdBytes, fastMemorySpaceOpt, maximalFusion, computeToleranceThreshold); - fusion.run(); + + if (affineFusionMode == FusionMode::ProducerConsumer) + fusion.runProducerConsumerFusionOnly(); + else if (affineFusionMode == FusionMode::Sibling) + fusion.runSiblingFusionOnly(); + else + fusion.runGreedyFusion(); } diff --git a/mlir/lib/Transforms/PassDetail.h b/mlir/lib/Transforms/PassDetail.h --- a/mlir/lib/Transforms/PassDetail.h +++ b/mlir/lib/Transforms/PassDetail.h @@ -10,6 +10,7 @@ #define TRANSFORMS_PASSDETAIL_H_ #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/Passes.h" namespace mlir { class AffineDialect; diff --git a/mlir/test/Transforms/loop-fusion-4.mlir b/mlir/test/Transforms/loop-fusion-4.mlir --- a/mlir/test/Transforms/loop-fusion-4.mlir +++ b/mlir/test/Transforms/loop-fusion-4.mlir @@ -1,54 +1,13 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion -split-input-file | FileCheck %s -// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal" -split-input-file | FileCheck %s --check-prefix=MAXIMAL +// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="mode=producer" -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER +// RUN: mlir-opt -allow-unregistered-dialect %s -affine-loop-fusion="fusion-maximal mode=sibling" -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL -// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir. +// Part I of fusion tests in mlir/test/Transforms/loop-fusion.mlir. // Part II of fusion tests in mlir/test/Transforms/loop-fusion-2.mlir // Part III of fusion tests in mlir/test/Transforms/loop-fusion-3.mlir -// ----- - -func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : memref<1x64xf32, 1>, %arg2 : memref<1x64xf32, 1>) { - %cst_0 = constant 0.000000e+00 : f32 - %cst_1 = constant 1.000000e+00 : f32 - affine.for %arg3 = 0 to 1 { - affine.for %arg4 = 0 to 64 { - %accum = affine.for %arg5 = 0 to 64 iter_args (%prevAccum = %cst_0) -> f32 { - %4 = affine.load %arg0[%arg5, %arg4] : memref<64x64xf32, 1> - %5 = addf %prevAccum, %4 : f32 - affine.yield %5 : f32 - } - %accum_dbl = addf %accum, %accum : f32 - affine.store %accum_dbl, %arg1[%arg3, %arg4] : memref<1x64xf32, 1> - } - } - affine.for %arg3 = 0 to 1 { - affine.for %arg4 = 0 to 64 { - // Following loop trip count does not match the corresponding source trip count. - %accum = affine.for %arg5 = 0 to 32 iter_args (%prevAccum = %cst_1) -> f32 { - %4 = affine.load %arg0[%arg5, %arg4] : memref<64x64xf32, 1> - %5 = mulf %prevAccum, %4 : f32 - affine.yield %5 : f32 - } - %accum_sqr = mulf %accum, %accum : f32 - affine.store %accum_sqr, %arg2[%arg3, %arg4] : memref<1x64xf32, 1> - } - } - return -} -// Test checks the loop structure is preserved after sibling fusion -// since the destination loop and source loop trip counts do not -// match. -// MAXIMAL-LABEL: func @reduce_add_non_maximal_f32_f32( -// MAXIMAL: %[[cst_0:.*]] = constant 0.000000e+00 : f32 -// MAXIMAL-NEXT: %[[cst_1:.*]] = constant 1.000000e+00 : f32 -// MAXIMAL-NEXT: affine.for %[[idx_0:.*]]= 0 to 1 { -// MAXIMAL-NEXT: affine.for %[[idx_1:.*]] = 0 to 64 { -// MAXIMAL-NEXT: %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) { -// MAXIMAL-NEXT: %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) { - // Expects fusion of producer into consumer at depth 4 and subsequent removal of // source loop. -// CHECK-LABEL: func @unflatten4d +// PRODUCER-CONSUMER-LABEL: func @unflatten4d func @unflatten4d(%arg1: memref<7x8x9x10xf32>) { %m = memref.alloc() : memref<5040xf32> %cf7 = constant 7.0 : f32 @@ -75,18 +34,18 @@ return } -// CHECK: affine.for -// CHECK-NEXT: affine.for -// CHECK-NEXT: affine.for -// CHECK-NEXT: affine.for -// CHECK-NOT: affine.for -// CHECK: return +// PRODUCER-CONSUMER: affine.for +// PRODUCER-CONSUMER-NEXT: affine.for +// PRODUCER-CONSUMER-NEXT: affine.for +// PRODUCER-CONSUMER-NEXT: affine.for +// PRODUCER-CONSUMER-NOT: affine.for +// PRODUCER-CONSUMER: return // ----- // Expects fusion of producer into consumer at depth 2 and subsequent removal of // source loop. -// CHECK-LABEL: func @unflatten2d_with_transpose +// PRODUCER-CONSUMER-LABEL: func @unflatten2d_with_transpose func @unflatten2d_with_transpose(%arg1: memref<8x7xf32>) { %m = memref.alloc() : memref<56xf32> %cf7 = constant 7.0 : f32 @@ -105,7 +64,48 @@ return } -// CHECK: affine.for -// CHECK-NEXT: affine.for -// CHECK-NOT: affine.for -// CHECK: return \ No newline at end of file +// PRODUCER-CONSUMER: affine.for +// PRODUCER-CONSUMER-NEXT: affine.for +// PRODUCER-CONSUMER-NOT: affine.for +// PRODUCER-CONSUMER: return + +// ----- + +// SIBLING-MAXIMAL-LABEL: func @reduce_add_non_maximal_f32_f32( +func @reduce_add_non_maximal_f32_f32(%arg0: memref<64x64xf32, 1>, %arg1 : memref<1x64xf32, 1>, %arg2 : memref<1x64xf32, 1>) { + %cst_0 = constant 0.000000e+00 : f32 + %cst_1 = constant 1.000000e+00 : f32 + affine.for %arg3 = 0 to 1 { + affine.for %arg4 = 0 to 64 { + %accum = affine.for %arg5 = 0 to 64 iter_args (%prevAccum = %cst_0) -> f32 { + %4 = affine.load %arg0[%arg5, %arg4] : memref<64x64xf32, 1> + %5 = addf %prevAccum, %4 : f32 + affine.yield %5 : f32 + } + %accum_dbl = addf %accum, %accum : f32 + affine.store %accum_dbl, %arg1[%arg3, %arg4] : memref<1x64xf32, 1> + } + } + affine.for %arg3 = 0 to 1 { + affine.for %arg4 = 0 to 64 { + // Following loop trip count does not match the corresponding source trip count. + %accum = affine.for %arg5 = 0 to 32 iter_args (%prevAccum = %cst_1) -> f32 { + %4 = affine.load %arg0[%arg5, %arg4] : memref<64x64xf32, 1> + %5 = mulf %prevAccum, %4 : f32 + affine.yield %5 : f32 + } + %accum_sqr = mulf %accum, %accum : f32 + affine.store %accum_sqr, %arg2[%arg3, %arg4] : memref<1x64xf32, 1> + } + } + return +} +// Test checks the loop structure is preserved after sibling fusion +// since the destination loop and source loop trip counts do not +// match. +// SIBLING-MAXIMAL: %[[cst_0:.*]] = constant 0.000000e+00 : f32 +// SIBLING-MAXIMAL-NEXT: %[[cst_1:.*]] = constant 1.000000e+00 : f32 +// SIBLING-MAXIMAL-NEXT: affine.for %[[idx_0:.*]]= 0 to 1 { +// SIBLING-MAXIMAL-NEXT: affine.for %[[idx_1:.*]] = 0 to 64 { +// SIBLING-MAXIMAL-NEXT: %[[result_1:.*]] = affine.for %[[idx_2:.*]] = 0 to 32 iter_args(%[[iter_0:.*]] = %[[cst_1]]) -> (f32) { +// SIBLING-MAXIMAL-NEXT: %[[result_0:.*]] = affine.for %[[idx_3:.*]] = 0 to 64 iter_args(%[[iter_1:.*]] = %[[cst_0]]) -> (f32) { \ No newline at end of file