diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -454,19 +454,24 @@
     }
   };
 
+  // Perform tiling and fusion in two steps. We need to respect the loop
+  // interchange here; filter parallel dimensions based on their order *after*
+  // permutation but pass in the original configuration *before* permutation,
+  // given the tiling and interchange happen together.
+  SmallVector<int64_t> outerTileSizes(tileSizes.size(), 0);
+  SmallVector<int64_t> innerTileSizes(tileSizes.size(), 0);
+  for (int64_t i : tileInterchange.take_front(split))
+    outerTileSizes[i] = tileSizes[i];
+  for (int64_t i : tileInterchange.drop_front(split))
+    innerTileSizes[i] = tileSizes[i];
+
   // Tile the outer parallel loops and fuse the output operands.
-  SmallVector<int64_t> outerTileSizes;
-  outerTileSizes.append(tileSizes.begin(), tileSizes.begin() + split);
-  outerTileSizes.append(tileSizes.size() - split, 0);
   if (failed(tileLoopNest.tileRootOp(b, outerTileSizes, tileInterchange,
                                      tileDistribution)))
     return failure();
   fuseProducersGreedily(tileLoopNest.getRootOp().getOutputOperands());
 
   // Tile the remaining loops and fuse the input operands.
-  SmallVector<int64_t> innerTileSizes;
-  innerTileSizes.append(split, 0);
-  innerTileSizes.append(tileSizes.begin() + split, tileSizes.end());
   if (failed(tileLoopNest.tileRootOp(b, innerTileSizes, tileInterchange,
                                      tileDistribution)))
     return failure();
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir
@@ -1,5 +1,6 @@
 // RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.matmul fuse tile-sizes=5,4,7 tile-interchange=1,0,2 run-enable-pass=false" -cse -split-input-file | FileCheck --check-prefix=MATMUL %s
-// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.generic fuse tile-sizes=5,4,7 tile-interchange=1,0,2 run-enable-pass=false" -cse -split-input-file | FileCheck --check-prefix=GENERIC %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.generic fuse tile-sizes=5,4,7 tile-interchange=1,0,2 run-enable-pass=false" -cse -split-input-file | FileCheck --check-prefix=GENERIC1 %s
+// RUN: mlir-opt %s -test-linalg-codegen-strategy="anchor-op=linalg.generic fuse tile-sizes=5,4,7 tile-interchange=0,2,1 run-enable-pass=false" -cse -split-input-file | FileCheck --check-prefix=GENERIC2 %s
 
 // MATMUL-DAG: #[[MAP0:.*]] = affine_map<(d0) -> (-d0 + 24, 5)>
 // MATMUL-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (-d0 + 12, 7)>
@@ -249,28 +250,28 @@
 
 #map0 = affine_map<(d0, d1) -> (d0, d1)>
 #map1 = affine_map<(d0, d1) -> (d0)>
 
-// GENERIC: fuse_outermost_reduction
-// GENERIC-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
-// GENERIC-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<10xf32>
+// GENERIC1: fuse_outermost_reduction
+//GENERIC1-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
+//GENERIC1-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<10xf32>
 func.func @fuse_outermost_reduction(%arg0: tensor<10x17xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<10x17xf32>) -> tensor<10x17xf32>
 
   // Cannot fuse the output fill since the reduction loop is the outermost loop.
-  // GENERIC: %[[T0:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[ARG1]]
+  // GENERIC1: %[[T0:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[ARG1]]
   %1 = linalg.fill ins(%cst : f32) outs(%arg1 : tensor<10xf32>) -> tensor<10xf32>
-  // GENERIC: scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]]
-  // GENERIC: scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
+  // GENERIC1: scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]]
+  // GENERIC1: scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]]
 
   // MATMUL the input fill has been fused.
-  // GENERIC: %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
-  // GENERIC-SAME: %[[IV1]], %[[IV0]]
-  // GENERIC: %[[T2:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T1]]
-  // GENERIC: %[[T3:.*]] = tensor.extract_slice %[[ARG3]]
-  // GENERIC-SAME: %[[IV1]]
-  // GENERIC: linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]]
+  // GENERIC1: %[[T1:.*]] = tensor.extract_slice %[[ARG0]]
+  //GENERIC1-SAME: %[[IV1]], %[[IV0]]
+  // GENERIC1: %[[T2:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T1]]
+  // GENERIC1: %[[T3:.*]] = tensor.extract_slice %[[ARG3]]
+  //GENERIC1-SAME: %[[IV1]]
+  // GENERIC1: linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]]
   %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<10x17xf32>) outs(%1 : tensor<10xf32>) {
   ^bb0(%arg2: f32, %arg3: f32):
     %3 = arith.addf %arg2, %arg3 : f32
@@ -281,39 +282,39 @@
 
 // -----
 
-// GENERIC-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
-// GENERIC-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (-d0 - d1 + 17, 8)>
-// GENERIC-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (-d1 - d2 + 17, d0)>
+// GENERIC1-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// GENERIC1-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (-d0 - d1 + 17, 8)>
+// GENERIC1-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (-d1 - d2 + 17, d0)>
 #map0 = affine_map<(d0, d1) -> (d0, d0 + d1)>
 #map1 = affine_map<(d0, d1) -> (d0, d1)>
 
-// GENERIC: fuse_non_rectangular
-// GENERIC-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
+// GENERIC1: fuse_non_rectangular
+//GENERIC1-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32>
 func.func @fuse_non_rectangular(%arg0: tensor<10x17xf32>, %arg1: tensor<10x8xf32>) -> tensor<10x8xf32> {
 
-  // GENERIC-DAG: %[[C0:.*]] = arith.constant 0 : index
-  // GENERIC-DAG: %[[C4:.*]] = arith.constant 4 : index
-  // GENERIC-DAG: %[[C5:.*]] = arith.constant 5 : index
-  // GENERIC-DAG: %[[C8:.*]] = arith.constant 8 : index
-  // GENERIC-DAG: %[[C10:.*]] = arith.constant 10 : index
+  // GENERIC1-DAG: %[[C0:.*]] = arith.constant 0 : index
+  // GENERIC1-DAG: %[[C4:.*]] = arith.constant 4 : index
+  // GENERIC1-DAG: %[[C5:.*]] = arith.constant 5 : index
+  // GENERIC1-DAG: %[[C8:.*]] = arith.constant 8 : index
+  // GENERIC1-DAG: %[[C10:.*]] = arith.constant 10 : index
   %cst = arith.constant 0.000000e+00 : f32
   %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<10x17xf32>) -> tensor<10x17xf32>
 
-  // GENERIC: scf.for %[[IV0:[0-9a-zA-Z]*]] = %[[C0]] to %[[C8]] step %[[C4]]
-  // GENERIC: scf.for %[[IV1:[0-9a-zA-Z]*]] = %[[C0]] to %[[C10]] step %[[C5]]
+  // GENERIC1: scf.for %[[IV0:[0-9a-zA-Z]*]] = %[[C0]] to %[[C8]] step %[[C4]]
+  // GENERIC1: scf.for %[[IV1:[0-9a-zA-Z]*]] = %[[C0]] to %[[C10]] step %[[C5]]
 
   // Compute producer on a hyper rectangular bounding box. Along the second dimenson,
   // the offset is set to the sum of the induction variables, and the upper bound
   // to either 8 (tile size) or 17 (sum of max indices (9+7) then + 1) minus the
   // induction variables.
 
-  // GENERIC-DAG: %[[SUM:.*]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV0]]
-  // GENERIC-DAG: %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[IV0]]
-  // GENERIC-DAG: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]], %[[IV0]]
-  // GENERIC: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
-  // GENERIC-SAME: %[[IV1]], %[[SUM]]
-  // GENERIC-SAME: , %[[UB1]]
-  // GENERIC: %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]]
+  // GENERIC1-DAG: %[[SUM:.*]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV0]]
+  // GENERIC1-DAG: %[[TS1:.*]] = affine.min #[[MAP1]](%[[IV1]], %[[IV0]]
+  // GENERIC1-DAG: %[[UB1:.*]] = affine.min #[[MAP2]](%[[TS1]], %[[IV1]], %[[IV0]]
+  // GENERIC1: %[[T0:.*]] = tensor.extract_slice %[[ARG0]]
+  //GENERIC1-SAME: %[[IV1]], %[[SUM]]
+  //GENERIC1-SAME: , %[[UB1]]
+  // GENERIC1: %[[T1:.*]] = linalg.fill ins(%{{.*}}{{.*}}outs(%[[T0]]
   %1 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "parallel"]} ins(%0 : tensor<10x17xf32>) outs(%arg1 : tensor<10x8xf32>) {
   ^bb0(%arg2: f32, %arg3: f32):
     %2 = arith.addf %arg2, %arg3 : f32
@@ -321,3 +322,40 @@
   } -> tensor<10x8xf32>
   func.return %1 : tensor<10x8xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
+
+func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> {
+  %five = arith.constant 5.0 : f32
+  %init = linalg.init_tensor [12, 25] : tensor<12x25xf32>
+  %fill = linalg.fill ins(%five : f32) outs(%init : tensor<12x25xf32>) -> tensor<12x25xf32>
+  %0 = linalg.generic {
+    indexing_maps = [#map0, #map1],
+    iterator_types = ["parallel", "reduction", "parallel"]
+  } ins(%input : tensor<12x7x25xf32>) outs(%fill : tensor<12x25xf32>) {
+  ^bb0(%arg0: f32, %arg1: f32):
+    %2 = arith.addf %arg0, %arg1 : f32
+    linalg.yield %2 : f32
+  } -> tensor<12x25xf32>
+  func.return %0 : tensor<12x25xf32>
+}
+
+// GENERIC2-LABEL: func @interchange_reduction
+// GENERIC2-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>)
+
+// GENERIC2-DAG: %[[C4:.+]] = arith.constant 4 : index
+// GENERIC2-DAG: %[[C5:.+]] = arith.constant 5 : index
+// GENERIC2-DAG: %[[C7:.+]] = arith.constant 7 : index
+
+// GENERIC2: %[[INIT:.+]] = linalg.init_tensor [12, 25] : tensor<12x25xf32>
+// GENERIC2: scf.for %[[IV0:.+]] = %{{.+}} to %{{.+}} step %[[C5]] iter_args(%[[FOR_ARG0:.+]] = %[[INIT]])
+// GENERIC2: scf.for %[[IV1:.+]] = %{{.+}} to %{{.+}} step %[[C7]] iter_args(%[[FOR_ARG1:.+]] = %[[FOR_ARG0]])
+// GENERIC2: %[[OUT_SLICE0:.+]] = tensor.extract_slice %[[FOR_ARG1]][%[[IV0]], %[[IV1]]]
+// GENERIC2: %[[FILL:.+]] = linalg.fill {{.+}} outs(%[[OUT_SLICE0]] : tensor<?x?xf32>)
+// GENERIC2: scf.for %[[IV2:.+]] = %{{.+}} to %{{.+}} step %[[C4]] iter_args(%[[FOR_ARG2:.+]] = %[[FILL]])
+// GENERIC2: %[[IN_SLICE:.+]] = tensor.extract_slice %[[INPUT]][%[[IV0]], %[[IV2]], %[[IV1]]]
+// GENERIC2: %[[OUT_SLICE2:.+]] = tensor.extract_slice %[[FOR_ARG2]][0, 0]
+// GENERIC2: %{{.+}} = linalg.generic {{.+}} ins(%[[IN_SLICE]] : tensor<?x?x?xf32>) outs(%[[OUT_SLICE2]] : tensor<?x?xf32>)