diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -1310,6 +1310,52 @@
   }
 };
 
+/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
+/// the value of the `outs` operand is not used within the op. This is only
+/// implemented for `linalg.generic` operations for now, but should hold for all
+/// linalg structured ops.
+struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.startRootUpdate(op);
+    bool modifiedOutput = false;
+    Location loc = op.getLoc();
+    for (OpOperand &opOperand : op.getOutputOpOperands()) {
+      if (!op.payloadUsesValueFromOpOperand(&opOperand)) {
+        Value operandVal = opOperand.get();
+        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
+        if (!operandType)
+          continue;
+
+        // If outs is already an `init_tensor` operation, nothing to do.
+        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
+        if (definingOp)
+          continue;
+        modifiedOutput = true;
+        SmallVector<Value> dynamicDims;
+        for (auto dim : llvm::enumerate(operandType.getShape())) {
+          if (dim.value() != ShapedType::kDynamicSize)
+            continue;
+          dynamicDims.push_back(rewriter.createOrFold<memref::DimOp>(
+              loc, operandVal, dim.index()));
+        }
+        Value initTensor = rewriter.create<InitTensorOp>(
+            loc, dynamicDims, operandType.getShape(),
+            operandType.getElementType());
+        op->setOperand(opOperand.getOperandNumber(), initTensor);
+      }
+    }
+    if (!modifiedOutput) {
+      rewriter.cancelRootUpdate(op);
+      return failure();
+    }
+    rewriter.finalizeRootUpdate(op);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::linalg::populateFoldReshapeOpsByLinearizationPatterns(
@@ -1339,6 +1385,7 @@
   auto *context = patterns.getContext();
   patterns.add<FuseElementwiseOps, FoldSplatConstants>(
       context, options.controlElementwiseOpsFusionFn);
+  patterns.add<RemoveOutsDependency>(context);
   populateFoldReshapeOpsByExpansionPatterns(patterns,
                                             options.controlFoldingReshapesFn);
   AffineApplyOp::getCanonicalizationPatterns(patterns, context);
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -662,3 +662,39 @@
   } -> tensor<3xf32>
   return %result : tensor<3xf32>
 }
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+#trait = {
+  indexing_maps = [#map, #map],
+  iterator_types = ["parallel", "parallel"]
+}
+func @break_outs_dependency(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
+{
+  %0 = linalg.generic #trait ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
+       ^bb0(%arg1 : f32, %arg2 : f32) :
+         %1 = addf %arg1, %arg1 : f32
+         linalg.yield %1 : f32
+       } -> tensor<?x?xf32>
+  %2 = linalg.generic #trait ins(%0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
+       ^bb0(%arg1 : f32, %arg2 : f32) :
+         %3 = mulf %arg1, %arg1 : f32
+         linalg.yield %3 : f32
+       } -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+// CHECK: func @break_outs_dependency(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>)
+// CHECK-DAG: %[[C0:.+]] = constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+// CHECK: %[[GENERIC1:.+]] = linalg.generic
+// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xf32>)
+// CHECK-DAG: %[[D0:.+]] = memref.dim %[[GENERIC1]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = memref.dim %[[GENERIC1]], %[[C1]]
+// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
+// CHECK: %[[RESULT:.+]] = linalg.generic
+// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xf32>)
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -1,6 +1,5 @@
-// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=false" -split-input-file -verify-each=0 | FileCheck %s
-// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file -verify-each=0 | FileCheck %s --check-prefix=FOLDUNITDIM
-
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=false" -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-fusion-for-tensor-ops="allow-folding-unit-dim-reshapes=true" -split-input-file | FileCheck %s --check-prefix=FOLDUNITDIM
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
 func @generic_op_reshape_producer_fusion(%arg0 : tensor,
@@ -30,13 +29,11 @@
 // CHECK-SAME: [0], [1, 2], [3]
 // CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME: [0], [1], [2, 3]
-// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[T0]]
-// CHECK-SAME: [0], [1], [2, 3]
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP5]], #[[MAP6]], #[[MAP6]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[ARG0]], %[[T1]] : tensor, tensor)
-// CHECK-SAME: outs(%[[T2]] : tensor)
+// CHECK-SAME: outs(%{{.+}} : tensor)
 // CHECK: %[[T4:.+]] = linalg.tensor_reshape %[[T3]]
 // CHECK-SAME: [0], [1], [2, 3]
 // CHECK-SAME: tensor into tensor
@@ -73,13 +70,11 @@
 // CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME: [0], [1, 2, 3]
 // CHECK-SAME: tensor into tensor
-// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME: [0], [1, 2, 3]
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]], #[[MAP2]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor, tensor)
-// CHECK-SAME: outs(%[[T2]] : tensor)
+// CHECK-SAME: outs(%{{.+}} : tensor)
 // CHECK: return %[[T3]] : tensor
 
 
@@ -115,13 +110,11 @@
 // CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME: [0, 1, 2], [3]
 // CHECK-SAME: tensor into tensor<3x4x?x?xf32>
-// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME: [0, 1], [2], [3, 4, 5]
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP8]], #[[MAP9]], #[[MAP10]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor<3x4x?x?x2x?xf32>, tensor<3x4x?x?xf32>)
-// CHECK-SAME: outs(%[[T2]] : tensor)
+// CHECK-SAME: outs(%{{.+}} : tensor)
 // CHECK: return %[[T3]] : tensor
 
 // -----
@@ -417,13 +410,11 @@
 // CHECK: %[[T1:.+]] = linalg.tensor_reshape %[[ARG1]]
 // CHECK-SAME: [0, 1, 2], [3]
 // CHECK-SAME: tensor into tensor
-// CHECK: %[[T2:.+]] = linalg.tensor_reshape %[[ARG0]]
-// CHECK-SAME: [0], [1, 2, 3]
 // CHECK: %[[T3:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP4]], #[[MAP5]]]
 // CHECK-SAME: ["parallel", "parallel", "parallel", "parallel"]
 // CHECK-SAME: ins(%[[T0]], %[[T1]] : tensor, tensor)
-// CHECK-SAME: outs(%[[T2]] : tensor)
+// CHECK-SAME: outs(%{{.+}} : tensor)
 // CHECK: return %[[T3]] : tensor
 
 // -----
@@ -501,8 +492,7 @@
 // FOLDUNITDIM-SAME: %[[ARG0:.+]]: tensor<1x?x1x2x1x4xf32>
 // FOLDUNITDIM-SAME: %[[ARG1:.+]]: tensor
 // FOLDUNITDIM-DAG: %[[RESHAPE:.+]] = linalg.tensor_reshape %[[ARG1]]
-// FOLDUNITDIM-DAG: %[[INIT:.+]] = linalg.init_tensor [1, %{{.+}}, 1, 2, 1, 4]
 // FOLDUNITDIM: linalg.generic
 // FOLDUNITDIM-SAME: ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
-// FOLDUNITDIM-SAME: outs(%[[INIT]] : tensor<1x?x1x2x1x4xf32>)
+// FOLDUNITDIM-SAME: outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)
 
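
For reference, the rewrite performed by the new `RemoveOutsDependency` pattern can be read off the `@break_outs_dependency` test above. The following is an illustrative, hand-written before/after sketch of the second `linalg.generic` in that test; the SSA names (`%c0`, `%d0`, `%init`, ...) are chosen for readability and are not literally what the pass prints.

// Before: the consumer uses the producer result %0 both as `ins` and `outs`,
// even though its payload never reads the `outs` value (%arg2).
%2 = linalg.generic #trait ins(%0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
     ^bb0(%arg1 : f32, %arg2 : f32) :
       %3 = mulf %arg1, %arg1 : f32
       linalg.yield %3 : f32
     } -> tensor<?x?xf32>

// After: the unused `outs` operand is replaced by a fresh `linalg.init_tensor`
// of the same dynamic shape, breaking the extra dependence on %0.
%c0 = constant 0 : index
%c1 = constant 1 : index
%d0 = memref.dim %0, %c0 : tensor<?x?xf32>
%d1 = memref.dim %0, %c1 : tensor<?x?xf32>
%init = linalg.init_tensor [%d0, %d1] : tensor<?x?xf32>
%2 = linalg.generic #trait ins(%0 : tensor<?x?xf32>) outs(%init : tensor<?x?xf32>) {
     ^bb0(%arg1 : f32, %arg2 : f32) :
       %3 = mulf %arg1, %arg1 : f32
       linalg.yield %3 : f32
     } -> tensor<?x?xf32>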