diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1590,56 +1590,6 @@
 //===---------------------------------------------------------------------===//
 
 namespace {
-/// Forces `outs` operands of linalg operations to use `linalg.init_tensor` if
-/// the value of the `outs` operand is not used within the op. This is only
-/// implemented for `linalg.generic` operations for now, but should hold for all
-/// linalg structured ops.
-struct RemoveOutsDependency : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(GenericOp op,
-                                PatternRewriter &rewriter) const override {
-    rewriter.startRootUpdate(op);
-    bool modifiedOutput = false;
-    Location loc = op.getLoc();
-    for (OpOperand *opOperand : op.getOutputOperands()) {
-      if (!op.payloadUsesValueFromOperand(opOperand)) {
-        Value operandVal = opOperand->get();
-        auto operandType = operandVal.getType().dyn_cast<RankedTensorType>();
-        if (!operandType)
-          continue;
-
-        // If outs is sparse, leave it to the sparse compiler.
-        if (sparse_tensor::getSparseTensorEncoding(operandVal.getType()))
-          continue;
-
-        // If outs is already an `init_tensor` operation, nothing to do.
-        auto definingOp = operandVal.getDefiningOp<InitTensorOp>();
-        if (definingOp)
-          continue;
-        modifiedOutput = true;
-        SmallVector<Value> dynamicDims;
-        for (const auto &dim : llvm::enumerate(operandType.getShape())) {
-          if (dim.value() != ShapedType::kDynamicSize)
-            continue;
-          dynamicDims.push_back(rewriter.createOrFold<tensor::DimOp>(
-              loc, operandVal, dim.index()));
-        }
-        Value initTensor = rewriter.create<InitTensorOp>(
-            loc, dynamicDims, operandType.getShape(),
-            operandType.getElementType());
-        op->setOperand(opOperand->getOperandNumber(), initTensor);
-      }
-    }
-    if (!modifiedOutput) {
-      rewriter.cancelRootUpdate(op);
-      return failure();
-    }
-    rewriter.finalizeRootUpdate(op);
-    return success();
-  }
-};
-
 /// Fold linalg.fill into linalg.generic
 struct FoldFillWithGenericOp : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
@@ -1686,8 +1636,7 @@
     const ControlFusionFn &controlElementwiseOpsFusion) {
   auto *context = patterns.getContext();
   patterns.add<FuseElementwiseOps>(context, controlElementwiseOpsFusion);
-  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant,
-               RemoveOutsDependency>(context);
+  patterns.add<FoldFillWithGenericOp, FoldScalarOrSplatConstant>(context);
 }
 
 //===---------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -707,42 +707,6 @@
 // -----
 
-#map = affine_map<(d0, d1) -> (d0, d1)>
-#trait = {
-  indexing_maps = [#map, #map],
-  iterator_types = ["parallel", "parallel"]
-}
-func.func @break_outs_dependency(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32>
-{
-  %0 = linalg.generic #trait ins(%arg0 : tensor<?x?xf32>) outs(%arg0 : tensor<?x?xf32>) {
-       ^bb0(%arg1 : f32, %arg2 : f32) :
-         %1 = arith.addf %arg1, %arg1 : f32
-         linalg.yield %1 : f32
-       } -> tensor<?x?xf32>
-  %2 = linalg.generic #trait ins(%0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) {
-       ^bb0(%arg1 : f32, %arg2 : f32) :
-         %3 = arith.mulf %arg1, %arg1 : f32
-         linalg.yield %3 : f32
-       } -> tensor<?x?xf32>
-  return %2 : tensor<?x?xf32>
-}
-// CHECK: func @break_outs_dependency(
-// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?xf32>)
-// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
-// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
-// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
-// CHECK: %[[GENERIC1:.+]] = linalg.generic
-// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xf32>)
-// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[GENERIC1]], %[[C0]]
-// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[GENERIC1]], %[[C1]]
-// CHECK-DAG: %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
-// CHECK: %[[RESULT:.+]] = linalg.generic
-// CHECK-SAME: outs(%[[INIT]] : tensor<?x?xf32>)
-
-// -----
-
 func.func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi32>) {
   %cst = arith.constant 4.0 : f32
   %c42 = arith.constant 42 : i32