diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -334,16 +334,24 @@
     linalgOps.push_back(op);
   });
 
-  Aliases aliases;
-  LinalgDependenceGraph G(aliases, linalgOps);
   for (auto *op : llvm::reverse(linalgOps)) {
-    for (unsigned consumerIdx = 0, e = LinalgOp(op).getNumInputs();
-         consumerIdx < e; ++consumerIdx) {
-      if (auto fusionInfo = fuseProducerOf(b, op, consumerIdx, G, &folder))
-        eraseSet.insert(fusionInfo->originalProducer.getOperation());
+    for (unsigned id = 0, e = LinalgOp(op).getNumInputs(); id < e; ++id) {
+      // The graph is rebuilt for every operand because `linalgOps`, which it
+      // is constructed from, may have been updated by a fusion further below.
+      linalg::Aliases aliases;
+      linalg::LinalgDependenceGraph graph(aliases, linalgOps);
+      if (auto info = fuseProducerOf(b, op, id, graph, &folder)) {
+        auto *originalOp = info->originalProducer.getOperation();
+        eraseSet.insert(originalOp);
+        // Make `linalgOps` refer to the fused producer instead of the op that
+        // was just queued in `eraseSet`; never write through an end iterator.
+        auto producerIt =
+            std::find(linalgOps.begin(), linalgOps.end(), originalOp);
+        if (producerIt != linalgOps.end())
+          *producerIt = info->fusedProducer.getOperation();
+      }
     }
   }
-
   // The `fuseProducerOf` function performs structural checks and in particular
   // that no covering read or write exist between the consumer and the producer.
// As a consequence, the only fusions that may occur preserve subsequent diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -1,15 +1,11 @@ -// RUN: mlir-opt %s -linalg-fusion | FileCheck %s +// RUN: mlir-opt %s -linalg-fusion -split-input-file | FileCheck %s -#map0 = affine_map<(d0) -> (d0 + 2)> -#map1 = affine_map<(d0) -> (d0 + 4)> -#map2 = affine_map<(d0) -> (d0 + 3)> -#map3 = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> -#map4 = affine_map<(d0) -> (d0)> -#map5 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> -#map6 = affine_map<(d0, d1) -> (d0, d1)> -// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)> - -func @f1(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { +func @f1(%A: memref, + %B: memref, + %C: memref, + %D: memref, + %E: memref + ) -> memref { %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index @@ -17,15 +13,27 @@ %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref %2 = dim %B, 1 : memref - linalg.matmul(%A, %B, %C) : memref, memref, memref + linalg.matmul(%A, %B, %C) : + memref, + memref, + memref %c1 = constant 1 : index loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { - %5 = std.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref - %7 = std.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref - %8 = std.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + %5 = std.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : + memref to + memref + %7 = std.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] : + memref to + memref + %8 = std.subview %C[%arg5, %arg6][%c2, %c3][%c1, %c1] : + memref to + memref + linalg.matmul(%5, %7, %8) : + memref, + memref, + memref } } } @@ -40,23 +48,43 @@ // CHECK: loop.for 
// CHECK: linalg.matmul -func @f2(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { +// ----- + +// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)> +func @f2(%A: memref, + %B: memref, + %C: memref, + %D: memref, + %E: memref + ) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.matmul(%A, %B, %C) : memref, memref, memref + linalg.matmul(%A, %B, %C) : + memref, + memref, + memref %0 = dim %C, 0 : memref %1 = dim %C, 1 : memref %2 = dim %D, 1 : memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { - %5 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref - %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref - %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + %5 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : + memref to + memref + %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : + memref to + memref + %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : + memref to + memref + linalg.matmul(%5, %7, %8) : + memref, + memref, + memref } } } @@ -73,23 +101,42 @@ // CHECK: linalg.matmul // CHECK: linalg.matmul -func @f3(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { +// ----- + +func @f3(%A: memref, + %B: memref, + %C: memref, + %D: memref, + %E: memref + ) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.matmul(%A, %B, %C) : memref, memref, memref + linalg.matmul(%A, %B, %C) : + memref, + memref, + memref %0 = dim %D, 0 : memref %1 = dim %D, 1 : memref %2 = dim %C, 1 : memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { - %5 = std.subview 
%D[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref - %7 = std.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref - %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + %5 = std.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] : + memref to + memref + %7 = std.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] : + memref to + memref + %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : + memref to + memref + linalg.matmul(%5, %7, %8) : + memref, + memref, + memref } } } @@ -106,24 +153,46 @@ // CHECK: linalg.matmul // CHECK: linalg.matmul -func @f4(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { +// ----- + +func @f4(%A: memref, + %B: memref, + %C: memref, + %D: memref, + %E: memref + ) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.matmul(%A, %B, %C) : memref, memref, memref - linalg.matmul(%A, %B, %D) : memref, memref, memref + linalg.matmul(%A, %B, %C) : + memref, + memref, + memref + linalg.matmul(%A, %B, %D) : + memref, + memref, + memref %0 = dim %C, 0 : memref %1 = dim %C, 1 : memref %2 = dim %D, 1 : memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { - %5 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref - %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref - %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + %5 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : + memref to + memref + %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : + memref to + memref + %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : + memref to + memref + linalg.matmul(%5, %7, %8) : + memref, + memref, + memref } } } @@ -142,7 +211,15 @@ // CHECK: linalg.matmul // CHECK: 
linalg.matmul -func @f5(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { +// ----- + +// CHECK-DAG: #[[strided2D:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)> +func @f5(%A: memref, + %B: memref, + %C: memref, + %D: memref, + %E: memref + ) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index @@ -151,15 +228,30 @@ %0 = dim %B, 1 : memref %1 = dim %D, 0 : memref %2 = dim %D, 1 : memref - linalg.matmul(%A, %B, %C) : memref, memref, memref - linalg.matmul(%C, %B, %D) : memref, memref, memref + linalg.matmul(%A, %B, %C) : + memref, + memref, + memref + linalg.matmul(%C, %B, %D) : + memref, + memref, + memref loop.for %arg5 = %c0 to %1 step %c2 { loop.for %arg6 = %c0 to %0 step %c3 { loop.for %arg7 = %c0 to %2 step %c4 { - %5 = std.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref - %7 = std.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref - %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + %5 = std.subview %D[%arg5, %arg7][%c2, %c4][%c1, %c1] : + memref to + memref + %7 = std.subview %B[%arg7, %arg6][%c4, %c3][%c1, %c1] : + memref to + memref + %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : + memref to + memref + linalg.matmul(%5, %7, %8) : + memref, + memref, + memref } } } @@ -170,23 +262,39 @@ // CHECK-DAG: %[[B_1:.*]] = dim %[[B]], 1 : memref // CHECK-DAG: %[[D_0:.*]] = dim %[[D]], 0 : memref // CHECK-DAG: %[[D_1:.*]] = dim %[[D]], 1 : memref -// Don't fuse C due to false dependence, note that this is too conservative though. 
-// CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[B_1]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // CHECK: linalg.matmul // CHECK: linalg.matmul +// CHECK: linalg.matmul + +// ----- -func @f6(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { +#map0 = affine_map<(d0) -> (d0 + 2)> +#map1 = affine_map<(d0) -> (d0 + 4)> +#map2 = affine_map<(d0) -> (d0 + 3)> + +func @f6(%A: memref, + %B: memref, + %C: memref, + %D: memref, + %E: memref + ) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index %0 = dim %C, 1 : memref - linalg.matmul(%A, %B, %C) : memref, memref, memref - linalg.matmul(%A, %C, %E) : memref, memref, memref + linalg.matmul(%A, %B, %C) : + memref, + memref, + memref + linalg.matmul(%A, %C, %E) : + memref, + memref, + memref %1 = dim %C, 0 : memref %2 = dim %D, 1 : memref loop.for %arg5 = %c0 to %1 step %c2 { @@ -194,11 +302,20 @@ loop.for %arg7 = %c0 to %0 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) - %5 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref + %5 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : + memref to + memref %6 = affine.apply #map2(%arg6) - %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref - %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : + memref to + memref + %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : + memref to + memref + linalg.matmul(%5, %7, %8) : + memref, + memref, + memref } } } @@ -216,7 +333,14 @@ // CHECK: linalg.matmul // CHECK-NOT: linalg.matmul -func @f7(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { +// ----- + +func @f7(%A: 
memref, + %B: memref, + %C: memref, + %D: memref, + %E: memref + ) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index @@ -227,25 +351,49 @@ %2 = dim %C, 1 : memref %3 = dim %C, 0 : memref %4 = dim %D, 1 : memref - linalg.matmul(%A, %C, %E) : memref, memref, memref - linalg.matmul(%A, %B, %C) : memref, memref, memref + linalg.matmul(%A, %C, %E) : + memref, + memref, + memref + linalg.matmul(%A, %B, %C) : + memref, + memref, + memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { - %7 = std.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref - %9 = std.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref - %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%7, %9, %10) : memref, memref, memref + %7 = std.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : + memref to + memref + %9 = std.subview %C[%arg7, %arg6][%c4, %c3][%c1, %c1] : + memref to + memref + %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : + memref to + memref + linalg.matmul(%7, %9, %10) : + memref, + memref, + memref } } } loop.for %arg5 = %c0 to %3 step %c2 { loop.for %arg6 = %c0 to %4 step %c3 { loop.for %arg7 = %c0 to %2 step %c4 { - %7 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref - %9 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref - %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%7, %9, %10) : memref, memref, memref + %7 = std.subview %C[%arg5, %arg7][%c2, %c4][%c1, %c1] : + memref to + memref + %9 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : + memref to + memref + %10 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : + memref to + memref + linalg.matmul(%7, %9, %10) : + memref, + memref, + memref } } } @@ -270,7 +418,18 @@ // CHECK: linalg.matmul // CHECK-NOT: linalg.matmul -func @f8(%A: memref, %B: memref, %C: memref, %D: memref, 
%E: memref) -> memref { +// ----- + +#map0 = affine_map<(d0) -> (d0 + 2)> +#map1 = affine_map<(d0) -> (d0 + 4)> +#map2 = affine_map<(d0) -> (d0 + 3)> + +func @f8(%A: memref, + %B: memref, + %C: memref, + %D: memref, + %E: memref + ) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index @@ -278,19 +437,34 @@ %c2 = constant 2 : index %0 = dim %A, 0 : memref %1 = dim %A, 1 : memref - linalg.matmul(%A, %C, %D) : memref, memref, memref - linalg.matmul(%A, %B, %C) : memref, memref, memref + linalg.matmul(%A, %C, %D) : + memref, + memref, + memref + linalg.matmul(%A, %B, %C) : + memref, + memref, + memref %2 = dim %D, 1 : memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) - %5 = std.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : memref to memref + %5 = std.subview %A[%arg5, %arg7][%c2, %c4][%c1, %c1] : + memref to + memref %6 = affine.apply #map2(%arg6) - %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : memref to memref - %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : memref to memref - linalg.matmul(%5, %7, %8) : memref, memref, memref + %7 = std.subview %D[%arg7, %arg6][%c4, %c3][%c1, %c1] : + memref to + memref + %8 = std.subview %E[%arg5, %arg6][%c2, %c3][%c1, %c1] : + memref to + memref + linalg.matmul(%5, %7, %8) : + memref, + memref, + memref } } } @@ -306,6 +480,8 @@ // CHECK: linalg.matmul // CHECK-NOT: linalg.matmul +// ----- + #id_2d = affine_map<(i, j) -> (i, j)> #pointwise_2d_trait = { args_in = 2, @@ -313,7 +489,10 @@ indexing_maps = [#id_2d, #id_2d, #id_2d], iterator_types = ["parallel", "parallel"] } -func @pointwise(%A: memref, %B: memref, %C: memref, %D: memref) { +func @pointwise(%A: memref, + %B: memref, + %C: memref, + %D: memref) { %c1 = constant 1 : index %c0 = constant 0 : index %c3 = constant 3 : index @@ -322,19 +501,29 @@ ^bb0(%E: f32, %arg5: f32, %arg6: f32): // no 
predecessors %2 = addf %E, %arg5 : f32 linalg.yield %2 : f32 - }: memref, memref, memref + }: memref, + memref, + memref %0 = dim %B, 0 : memref %1 = dim %B, 1 : memref loop.for %arg4 = %c0 to %0 step %c2 { loop.for %arg5 = %c0 to %1 step %c3 { - %4 = std.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref - %5 = std.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref - %6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref + %4 = std.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] : + memref to + memref + %5 = std.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] : + memref to + memref + %6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : + memref to + memref linalg.generic #pointwise_2d_trait %4, %5, %6 { ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors %7 = mulf %arg6, %arg7 : f32 linalg.yield %7 : f32 - }: memref, memref, memref + }: memref, + memref, + memref } } return @@ -348,6 +537,15 @@ // CHECK: linalg.generic // CHECK: mulf +// ----- + +#id_2d = affine_map<(i, j) -> (i, j)> +#pointwise_2d_trait = { + args_in = 2, + args_out = 1, + indexing_maps = [#id_2d, #id_2d, #id_2d], + iterator_types = ["parallel", "parallel"] +} func @pointwise_no_view(%M: index, %N: index) { %c1 = constant 1 : index %c0 = constant 0 : index @@ -362,19 +560,29 @@ ^bb0(%e: f32, %arg5: f32, %arg6: f32): // no predecessors %2 = addf %e, %arg5 : f32 linalg.yield %2 : f32 - }: memref, memref, memref + }: memref, + memref, + memref %0 = dim %B, 0 : memref %1 = dim %B, 1 : memref loop.for %arg4 = %c0 to %0 step %c2 { loop.for %arg5 = %c0 to %1 step %c3 { - %4 = std.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref - %5 = std.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref - %6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : memref to memref + %4 = std.subview %B[%arg4, %arg5][%c2, %c3][%c1, %c1] : + memref to + memref + %5 = std.subview %C[%arg4, %arg5][%c2, %c3][%c1, %c1] : + memref to + memref + %6 = 
std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] : + memref to + memref linalg.generic #pointwise_2d_trait %4, %5, %6 { ^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors %7 = mulf %arg6, %arg7 : f32 linalg.yield %7 : f32 - }: memref, memref, memref + }: memref, + memref, + memref } } return @@ -388,6 +596,17 @@ // CHECK: linalg.generic // CHECK: mulf +// ----- + +#map5 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +#map6 = affine_map<(d0, d1) -> (d0, d1)> +#id_2d = affine_map<(i, j) -> (i, j)> +#pointwise_2d_trait = { + args_in = 2, + args_out = 1, + indexing_maps = [#id_2d, #id_2d, #id_2d], + iterator_types = ["parallel", "parallel"] +} func @indexed_generic_test(%A: memref, %B: memref, %C: memref, @@ -439,3 +658,76 @@ // CHECK: addf // CHECK: linalg.indexed_generic // CHECK: index_cast + +// ----- + +#map0 = affine_map<(d0, d1) -> (d0)> +#map1 = affine_map<(d0, d1) -> (d0, d1)> +#map2 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> + +func @fusion_of_three(%arg0: memref<100x10xf32>, + %arg1: memref<100xf32>, + %arg2: memref<100x10xf32>) { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = alloc() {temp = true} : memref<100x10xf32> + linalg.generic { + args_in = 1 : i64, + args_out = 1 : i64, + indexing_maps = [#map0, #map1], + iterator_types = ["parallel", "parallel"] + } %arg1, %0 { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + linalg.yield %arg3 : f32 + }: memref<100xf32>, memref<100x10xf32> + %1 = alloc() {temp = true} : memref<100x10xf32> + linalg.generic { + args_in = 2 : i64, + args_out = 1 : i64, + indexing_maps = [#map1, #map1, #map1], + iterator_types = ["parallel", "parallel"] + } %arg0, %0, %1 { + ^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors + %2 = subf %arg3, %arg4 : f32 + linalg.yield %2 : f32 + }: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32> + dealloc %0 : memref<100x10xf32> + %2 = dim %1, 0 : memref<100x10xf32> + %3 = dim %1, 1 : memref<100x10xf32> + %4 = dim 
%arg2, 0 : memref<100x10xf32> + %5 = dim %arg2, 1 : memref<100x10xf32> + loop.for %i = %c0 to %2 step %c1 { + loop.for %j = %c0 to %3 step %c1 { + %6 = std.subview %1[%i, %j][%c1, %c1][%c1, %c1] : + memref<100x10xf32> to memref + %7 = std.subview %arg2[%i, %j][%c1, %c1][%c1, %c1] : + memref<100x10xf32> to memref + linalg.generic { + args_in = 1 : i64, + args_out = 1 : i64, + indexing_maps = [#map1, #map1], + iterator_types = ["parallel", "parallel"] + } %6, %7 { + ^bb0(%arg3: f32, %arg4: f32): // no predecessors + %8 = exp %arg3 : f32 + linalg.yield %8 : f32 + }: memref, + memref + } + } + dealloc %1 : memref<100x10xf32> + return +} +// CHECK-LABEL: func @fusion +// CHECK-NOT: linalg.generic +// CHECK: loop.for +// CHECK: loop.for +// CHECK-NOT: loop.for +// CHECK: linalg.generic +// CHECK: linalg.yield +// CHECK: linalg.generic +// CHECK: subf +// CHECK: linalg.yield +// CHECK: linalg.generic +// CHECK: exp +// CHECK: linalg.yield