diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -332,16 +332,25 @@
       linalgOps.push_back(op);
   });
 
-  Aliases aliases;
-  LinalgDependenceGraph G(aliases, linalgOps);
+  // TODO(pifon, ntv): LinalgDependenceGraph should be able to update itself.
+  // The current naive and expensive reconstruction of the graph should be
+  // removed.
   for (auto *op : llvm::reverse(linalgOps)) {
-    for (unsigned consumerIdx = 0, e = LinalgOp(op).getNumInputs();
-         consumerIdx < e; ++consumerIdx) {
-      if (auto fusionInfo = fuseProducerOf(b, op, consumerIdx, G, &folder))
-        eraseSet.insert(fusionInfo->originalProducer.getOperation());
+    for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
+      linalg::Aliases aliases;
+      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
+      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
+        auto *originalOp = info->originalProducer.getOperation();
+        eraseSet.insert(originalOp);
+        // Substitute the fused producer for the original one in `linalgOps`
+        // so that the graphs rebuilt in later iterations see the fused op
+        // rather than one already added to `eraseSet`.
+        auto it = llvm::find(linalgOps, originalOp);
+        assert(it != linalgOps.end() && "original producer must be in linalgOps");
+        *it = info->fusedProducer.getOperation();
+      }
     }
   }
-
   // The `fuseProducerOf` function performs structural checks and in particular
   // that no covering read or write exist between the consumer and the producer.
   // As a consequence, the only fusions that may occur preserve subsequent
diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir
--- a/mlir/test/Dialect/Linalg/fusion.mlir
+++ b/mlir/test/Dialect/Linalg/fusion.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s --dump-input-on-failure
 
 func @f1(%A: memref,
          %B: memref,
@@ -262,13 +262,23 @@
 // CHECK-DAG:  %[[B_1:.*]] = dim %[[B]], 1 : memref
 // CHECK-DAG:  %[[D_0:.*]] = dim %[[D]], 0 : memref
 // CHECK-DAG:  %[[D_1:.*]] = dim %[[D]], 1 : memref
-// Don't fuse C due to false dependence, note that this is too conservative though.
-// CHECK:  linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}})
-// CHECK:  loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} {
-// CHECK:    loop.for %{{.*}} = %{{.*}} to %[[B_1]] step %{{.*}} {
-// CHECK:      loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} {
-// CHECK:        linalg.matmul
-// CHECK:        linalg.matmul
+// CHECK:  loop.for %[[I:.*]] = %{{.*}} to %[[D_0]] step %{{.*}} {
+// CHECK:    loop.for %[[J:.*]] = %{{.*}} to %[[B_1]] step %{{.*}} {
+// CHECK:      loop.for %[[K:.*]] = %{{.*}} to %[[D_1]] step %{{.*}} {
+// CHECK-DAG:    %[[D_IK:.*]] = std.subview %[[D]][%[[I]], %[[K]]]
+// CHECK-DAG:    %[[B_KJ:.*]] = std.subview %[[B]][%[[K]], %[[J]]]
+// CHECK-DAG:    %[[E_IJ:.*]] = std.subview %[[E]][%[[I]], %[[J]]]
+// CHECK:        dim
+// CHECK-DAG:    %[[C_I0:.*]] = std.subview %[[C]][%[[I]], %{{.*}}]
+// CHECK-DAG:    %[[B_0K:.*]] = std.subview %[[B]][%{{.*}}, %[[K]]]
+// CHECK-DAG:    %[[D_IK_:.*]] = std.subview %[[D]][%[[I]], %[[K]]]
+// CHECK:        dim
+// CHECK-DAG:    %[[A_I0:.*]] = std.subview %[[A]][%[[I]], %{{.*}}]
+// CHECK-DAG:    %[[B_00:.*]] = std.subview %[[B]][%{{.*}}, %{{.*}}]
+// CHECK-DAG:    %[[C_I0_:.*]] = std.subview %[[C]][%[[I]], %{{.*}}]
+// CHECK:        linalg.matmul(%[[A_I0]], %[[B_00]], %[[C_I0_]])
+// CHECK:        linalg.matmul(%[[C_I0]], %[[B_0K]], %[[D_IK_]])
+// CHECK:        linalg.matmul(%[[D_IK]], %[[B_KJ]], %[[E_IJ]])
 
 // -----
 
@@ -659,3 +669,76 @@
 // CHECK:         addf
 // CHECK:       linalg.indexed_generic
 // CHECK:         index_cast
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0)>
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+#map2 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
+
+func @fusion_of_three(%arg0: memref<100x10xf32>,
+                      %arg1: memref<100xf32>,
+                      %arg2: memref<100x10xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = alloc() {temp = true} : memref<100x10xf32>
+  linalg.generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map1],
+    iterator_types = ["parallel", "parallel"]
+  } %arg1, %0 {
+    ^bb0(%arg3: f32, %arg4: f32):       // no predecessors
+      linalg.yield %arg3 : f32
+    }: memref<100xf32>, memref<100x10xf32>
+  %1 = alloc() {temp = true} : memref<100x10xf32>
+  linalg.generic {
+    args_in = 2 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map1, #map1, #map1],
+    iterator_types = ["parallel", "parallel"]
+  } %arg0, %0, %1 {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):       // no predecessors
+      %2 = subf %arg3, %arg4 : f32
+      linalg.yield %2 : f32
+    }: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32>
+  dealloc %0 : memref<100x10xf32>
+  %2 = dim %1, 0 : memref<100x10xf32>
+  %3 = dim %1, 1 : memref<100x10xf32>
+  %4 = dim %arg2, 0 : memref<100x10xf32>
+  %5 = dim %arg2, 1 : memref<100x10xf32>
+  loop.for %i = %c0 to %2 step %c1 {
+    loop.for %j = %c0 to %3 step %c1 {
+      %6 = std.subview %1[%i, %j][%c1, %c1][%c1, %c1] :
+      memref<100x10xf32> to memref
+      %7 = std.subview %arg2[%i, %j][%c1, %c1][%c1, %c1] :
+      memref<100x10xf32> to memref
+      linalg.generic {
+        args_in = 1 : i64,
+        args_out = 1 : i64,
+        indexing_maps = [#map1, #map1],
+        iterator_types = ["parallel", "parallel"]
+      } %6, %7 {
+        ^bb0(%arg3: f32, %arg4: f32):   // no predecessors
+          %8 = exp %arg3 : f32
+          linalg.yield %8 : f32
+        }: memref,
+           memref
+    }
+  }
+ dealloc %1 : memref<100x10xf32>
+  return
+}
+// CHECK-LABEL: func @fusion_of_three
+// CHECK-NOT:     linalg.generic
+// CHECK:         loop.for
+// CHECK:           loop.for
+// CHECK-NOT:         loop.for
+// CHECK:             linalg.generic
+// CHECK:               linalg.yield
+// CHECK:             linalg.generic
+// CHECK:               subf
+// CHECK:               linalg.yield
+// CHECK:             linalg.generic
+// CHECK:               exp
+// CHECK:               linalg.yield