diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -split-input-file %s | FileCheck %s +// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -cse -split-input-file %s | FileCheck %s func.func @gemm_fill_fusion(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32> { %c0 = arith.constant 0 : index @@ -271,18 +271,12 @@ // CHECK-DAG: %[[ST_ARG0:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] // CHECK-DAG: %[[ST_ARG1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] // CHECK-DAG: %[[ST_ARG2:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK: %[[LHS:.+]] = linalg.matmul +// CHECK: %[[MATMUL:.+]] = linalg.matmul // CHECK-SAME: ins(%[[ST_ARG0]], %[[ST_ARG1]] : // CHECK-SAME: outs(%[[ST_ARG2]] : -// CHECK-DAG: %[[ST_ARG0_1:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 0] -// CHECK-DAG: %[[ST_ARG1_1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[IV1]]] -// CHECK-DAG: %[[ST_ARG2_1:.+]] = tensor.extract_slice %[[ARG2]][%[[IV0]], %[[IV1]]] -// CHECK: %[[RHS:.+]] = linalg.matmul -// CHECK-SAME: ins(%[[ST_ARG0_1]], %[[ST_ARG1_1]] : -// CHECK-SAME: outs(%[[ST_ARG2_1]] : // CHECK: %[[ST_ARG6:.+]] = tensor.extract_slice %[[ARG6]][%[[IV0]], %[[IV1]]] // CHECK: %[[ST_RESULT:.+]] = linalg.generic -// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : +// CHECK-SAME: ins(%[[MATMUL]], %[[MATMUL]] : // CHECK-SAME: outs(%[[ST_ARG6]] : // CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[ST_RESULT]] // CHECK-SAME: into %[[ARG6]][%[[IV0]], %[[IV1]]] @@ -401,3 +395,69 @@ // CHECK-SAME: outs(%[[SLICE_ARG6]] : // CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[TILE_GEMM3]] into %[[ARG8]][%[[IV]], 0] 
[%[[TILE_M]], %[[N3]]] // CHECK: scf.yield %[[UPDATE]] + +// ----- + +func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %cst_0 = arith.constant 0xFF800000 : f32 + %0 = tensor.empty() : tensor<30xf32> + %1 = linalg.fill ins(%cst_0 : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32> + %2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor<30x3xf32>) outs(%1 : tensor<30xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %8 = arith.maxf %arg2, %arg1 : f32 + linalg.yield %8 : f32 + } -> tensor<30xf32> + %3 = tensor.empty() : tensor<30x3xf32> + %4 = linalg.fill ins(%cst : f32) outs(%0 : tensor<30xf32>) -> tensor<30xf32> + %5:2 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0, %2 : tensor<30x3xf32>, tensor<30xf32>) outs(%4, %3 : tensor<30xf32>, tensor<30x3xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32, %arg4: f32): + %8 = arith.subf %arg1, %arg2 : f32 + %9 = math.exp %8 : f32 + %10 = arith.addf %arg3, %9 : f32 + linalg.yield %10, %9 : f32, f32 + } -> (tensor<30xf32>, tensor<30x3xf32>) + %6 = linalg.generic { + __internal_linalg_transform__ = "reduction_sequence_fusion", + indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, + affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"]} + ins(%5#1, %5#0 : tensor<30x3xf32>, tensor<30xf32>) outs(%3 : tensor<30x3xf32>) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %8 = arith.divf %arg1, %arg2 : f32 + linalg.yield %8 : f32 + } -> tensor<30x3xf32> + return %6 : tensor<30x3xf32> +} +// CHECK: func @reduction_sequence(%[[ARG0:.+]]: tensor<30x3xf32>) +// CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<30xf32> +// CHECK-DAG: 
%[[INIT1:.+]] = tensor.empty() : tensor<30x3xf32> +// CHECK: %[[RESULT:[a-zA-Z0-9]+]] = scf.for %[[IV:[a-zA-Z0-9]+]] +// CHECK-SAME: iter_args(%[[ITERARG0:[a-zA-Z0-9]+]] = %[[INIT1]]) +// CHECK-DAG: %[[ARG0_SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV]], 0] +// CHECK-DAG: %[[INIT0_SLICE:.+]] = tensor.extract_slice %[[INIT0]][%[[IV]]] +// CHECK: %[[FILL0:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0_SLICE]] : +// CHECK: %[[GENERIC0:.+]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0_SLICE]] : +// CHECK-SAME: outs(%[[FILL0]] : +// CHECK: %[[FILL1:.+]] = linalg.fill +// CHECK-SAME: outs(%[[INIT0_SLICE]] : +// CHECK: %[[INIT1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0] +// CHECK: %[[GENERIC1:.+]]:2 = linalg.generic +// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[GENERIC0]] : +// CHECK-SAME: outs(%[[FILL1]], %[[INIT1_SLICE]] : +// CHECK: %[[ITERARG0_SLICE:.+]] = tensor.extract_slice %[[ITERARG0]][%[[IV]], 0] +// CHECK: %[[GENERIC2:.+]] = linalg.generic +// CHECK-SAME: ins(%[[GENERIC1]]#1, %[[GENERIC1]]#0 : +// CHECK-SAME: outs(%[[ITERARG0_SLICE]] : +// CHECK: %[[INSERTSLICE:.+]] = tensor.insert_slice %[[GENERIC2]] into %[[ITERARG0]][%[[IV]], 0] +// CHECK: scf.yield %[[INSERTSLICE]] +// CHECK: return %[[RESULT]] diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -368,6 +368,9 @@ // 5. Tile and fuse a sequence of GEMMs by tiling and fusing only along M // dimension. addPatternForTileAndFuse(context, patterns, "gemm_sequence_fusion", {10}); + // 6. Tile and fuse a sequence of back-to-back reduction ops. + addPatternForTileAndFuse(context, patterns, "reduction_sequence_fusion", + {10}); return; } if (testLoweringToScalar) {