diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -225,9 +225,9 @@ LogicalResult tileRootOp(OpBuilder &b, ArrayRef tileSizes, ArrayRef tileInterchange); - /// Fuse the producer of `rootOpOperand` into the tile loop nest. Returns the - /// fused producer of fails if fusion is not possible. - FailureOr fuseProducer(OpBuilder &b, OpOperand *rootOpOperand); + /// Fuse the producer of `consumerOpOperand` into the tile loop nest. Returns + /// the fused producer or fails if fusion is not possible. + FailureOr fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand); /// Returns the replacement results for the original untiled root operation. ValueRange getRootOpReplacementResults(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -317,8 +317,11 @@ FailureOr TileLoopNest::fuseProducer(OpBuilder &b, OpOperand *consumerOpOperand) { - assert(tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) != 0 && - "expect the operand owner is the root operation or a fused producer"); + // Check if the consumer has been tiled before. For example, it may not have + // been tiled if the outermost tile loop is a reduction loop. 
+ if (tiledRootAndFusedOpsLoops.count(consumerOpOperand->getOwner()) == 0) + return failure(); + assert(this->isValid() && "expect the tile loop nest to satisfy all invariants"); diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir --- a/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir +++ b/mlir/test/Dialect/Linalg/tile-and-fuse-on-tensors.mlir @@ -232,6 +232,41 @@ // ----- +#map0 = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d0)> + +// CHECK: fuse_outermost_reduction +// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: tensor<10x17xf32> +// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: tensor<10xf32> +func @fuse_outermost_reduction(%arg0: tensor<10x17xf32>, + %arg1: tensor<10xf32>) -> tensor<10xf32> { + %cst = arith.constant 0.000000e+00 : f32 + %0 = linalg.fill(%cst, %arg0) : f32, tensor<10x17xf32> -> tensor<10x17xf32> + + // Cannot fuse the output fill since the reduction loop is the outermost loop; the fill of %arg1 is expected to stay above the tile loop nest and to initialize the iter_args of the outer scf.for. + // CHECK: %[[T0:.*]] = linalg.fill(%{{.*}}, %[[ARG1]]) + %1 = linalg.fill(%cst, %arg1) : f32, tensor<10xf32> -> tensor<10xf32> + + // CHECK: scf.for %[[IV0:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG2:.*]] = %[[T0]] + // CHECK: scf.for %[[IV1:[0-9a-zA-Z]*]] = {{.*}} iter_args(%[[ARG3:.*]] = %[[ARG2]] + + // Check the input fill has been fused: it is expected to fill a slice of %arg0 indexed by both loop induction variables and to feed the tiled linalg.generic, whose output is a slice of the inner loop's iter_arg. 
+ // CHECK: %[[T1:.*]] = tensor.extract_slice %[[ARG0]] + // CHECK-SAME: %[[IV1]], %[[IV0]] + // CHECK: %[[T2:.*]] = linalg.fill(%{{.*}}, %[[T1]]) + // CHECK: %[[T3:.*]] = tensor.extract_slice %[[ARG3]] + // CHECK-SAME: %[[IV1]] + // CHECK: linalg.generic {{.*}} ins(%[[T2]] {{.*}} outs(%[[T3]] + %2 = linalg.generic {indexing_maps = [#map0, #map1], iterator_types = ["parallel", "reduction"]} ins(%0 : tensor<10x17xf32>) outs(%1 : tensor<10xf32>) { + ^bb0(%arg2: f32, %arg3: f32): // no predecessors + %3 = arith.addf %arg2, %arg3 : f32 + linalg.yield %3 : f32 + } -> tensor<10xf32> + return %2 : tensor<10xf32> +} + +// ----- + // CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 + d1)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (8, -d0 - d1 + 17)> // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, -d1 - d2 + 17)>