diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -913,35 +913,12 @@ return success(); } }; - -/// Fold linalg.fill into linalg.generic -struct FoldFillWithGenericOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GenericOp genericOp, - PatternRewriter &rewriter) const override { - if (!genericOp.hasTensorSemantics()) - return failure(); - bool fillFound = false; - Block &payload = genericOp.region().front(); - for (OpOperand *opOperand : genericOp.getInputOperands()) { - FillOp fillOp = opOperand->get().getDefiningOp(); - if (fillOp) { - fillFound = true; - payload.getArgument(opOperand->getOperandNumber()) - .replaceAllUsesWith(fillOp.value()); - } - } - // fail if there are no FillOps to fold. - return success(fillFound); - } -}; } // namespace void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.add(context); + DeadArgsGenericOpInputs>(context); } LogicalResult GenericOp::fold(ArrayRef, diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -2215,8 +2215,38 @@ return success(); } }; -} // namespace +/// Fold linalg.fill into linalg.generic +struct FoldFillWithGenericOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + if (!genericOp.hasTensorSemantics()) + return failure(); + bool fillFound = false; + Block &payload = genericOp.region().front(); + for (OpOperand *opOperand : genericOp.getInputOperands()) { + FillOp fillOp = opOperand->get().getDefiningOp(); + if (fillOp) { + auto arg 
= payload.getArgument(opOperand->getOperandNumber()); + if (arg.use_empty()) + continue; + // The fill scalar's type may differ from the element type (fill casts + // it); only fold when no conversion is needed to keep the IR typed. + Value fillVal = fillOp.value(); + if (fillVal.getType() != arg.getType()) + continue; + fillFound = true; + // The payload is mutated in place, so notify the rewriter. + rewriter.updateRootInPlace( + genericOp, [&]() { arg.replaceAllUsesWith(fillVal); }); + } + } + return success(fillFound); + } +}; +} // namespace //===---------------------------------------------------------------------===// // Methods that add patterns described in this file to a pattern list. //===---------------------------------------------------------------------===// @@ -2261,7 +2291,7 @@ patterns.add(context, options.controlElementwiseOpsFusionFn); - patterns.add(context); + patterns.add(context); populateSparseTensorRewriting(patterns); populateFoldReshapeOpsByExpansionPatterns(patterns, options.controlFoldingReshapesFn); diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -343,59 +343,6 @@ // ----- -// CHECK-LABEL: func @fold_fill_generic_basic -// CHECK-SAME: (%[[ARG0:.*]]: tensor) -> tensor { -// CHECK-NOT: linalg.fill -// CHECK: %[[GENERIC_OP:.*]] = linalg.generic -// CHECK-SAME: ins(%[[ARG0]] : tensor) -// CHECK-SAME: outs({{.*}} : tensor) { -#map0 = affine_map<(d0) -> (d0)> -func @fold_fill_generic_basic(%arg0: tensor) -> (tensor) { - %c0 = arith.constant 0 : index - %cst = arith.constant 7.0 : f32 - %0 = tensor.dim %arg0, %c0 : tensor - %1 = linalg.init_tensor [%0] : tensor - %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor) -> tensor - %3 = linalg.init_tensor [%0] : tensor - %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor, tensor) outs (%3:tensor) { - ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): - %5 = arith.addf 
%arg1, %arg2 : f32 - linalg.yield %5 : f32 - } -> tensor - return %4 : tensor -} - -// ----- - -// CHECK-LABEL: func @fold_fill_generic_mixedaccess -// CHECK-NOT: linalg.fill -// CHECK: %[[GENERIC_OP:.*]] = linalg.generic -// CHECK-NOT: ins -// CHECK-SAME: 
outs({{.*}} : tensor) { -#map0 = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d1, d0)> -func @fold_fill_generic_mixedaccess(%arg0: tensor) -> (tensor) { - %c0 = arith.constant 0 : index - %c1 = arith.constant 0 : index - %cst1 = arith.constant 7.0 : f32 - %cst2 = arith.constant 6.0 : f32 - %0 = tensor.dim %arg0, %c0 : tensor - %1 = tensor.dim %arg0, %c1 : tensor - %2 = linalg.init_tensor [%0, %1] : tensor - %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor) -> tensor - %4 = linalg.init_tensor [%1, %0] : tensor - %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor) -> tensor - %6 = linalg.init_tensor [%0, %1] : tensor - %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor, tensor) outs (%6:tensor) { - ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): - %8 = arith.divf %arg1, %arg2 : f32 - linalg.yield %8 : f32 - } -> tensor - return %7 : tensor -} - -// ----- - // CHECK-LABEL: func @remove_deadargs_generic_basic // CHECK-SAME: (%[[ARG0:.*]]: tensor) -> tensor { // CHECK: %[[GENERIC_OP:.*]] = linalg.generic diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -975,3 +975,56 @@ // CHECK: %[[PRODUCER:.+]] = linalg.generic // CHECK: linalg.generic // CHECK-SAME: ins(%[[PRODUCER]] + +// ----- + +// CHECK-LABEL: func @fold_fill_generic_basic +// CHECK-SAME: (%[[ARG0:.*]]: tensor) -> tensor { +// CHECK-NOT: linalg.fill +// CHECK: %[[GENERIC_OP:.*]] = linalg.generic +// CHECK-SAME: ins(%[[ARG0]] : tensor) +// CHECK-SAME: outs({{.*}} : tensor) { +#map0 = affine_map<(d0) -> (d0)> +func @fold_fill_generic_basic(%arg0: tensor) -> (tensor) { + %c0 = arith.constant 0 : index + %cst = arith.constant 7.0 : f32 + %0 = tensor.dim %arg0, %c0 : tensor + %1 = linalg.init_tensor [%0] : tensor + %2 = linalg.fill ins(%cst : 
f32) outs(%1 : tensor) -> tensor + %3 = linalg.init_tensor [%0] : tensor + %4 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types=["parallel"]} ins(%arg0, %2 : tensor, tensor) outs (%3:tensor) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %5 = arith.addf %arg1, %arg2 : f32 + linalg.yield %5 : f32 + } -> tensor + return %4 : tensor +} + +// ----- + +// CHECK-LABEL: func @fold_fill_generic_mixedaccess +// CHECK-NOT: linalg.fill +// CHECK: %[[GENERIC_OP:.*]] = linalg.generic +// CHECK-NOT: ins +// CHECK-SAME: outs({{.*}} : tensor) { +#map0 = affine_map<(d0, d1) -> (d0, d1)> +#map1 = affine_map<(d0, d1) -> (d1, d0)> +func @fold_fill_generic_mixedaccess(%arg0: tensor) -> (tensor) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %cst1 = arith.constant 7.0 : f32 + %cst2 = arith.constant 6.0 : f32 + %0 = tensor.dim %arg0, %c0 : tensor + %1 = tensor.dim %arg0, %c1 : tensor + %2 = linalg.init_tensor [%0, %1] : tensor + %3 = linalg.fill ins(%cst1 : f32) outs(%2 : tensor) -> tensor + %4 = linalg.init_tensor [%1, %0] : tensor + %5 = linalg.fill ins(%cst2 : f32) outs(%4 : tensor) -> tensor + %6 = linalg.init_tensor [%0, %1] : tensor + %7 = linalg.generic {indexing_maps = [#map0, #map1, #map0], iterator_types=["parallel","parallel"]} ins(%3, %5 : tensor, tensor) outs (%6:tensor) { + ^bb0(%arg1: f32, %arg2: f32, %arg3: f32): + %8 = arith.divf %arg1, %arg2 : f32 + linalg.yield %8 : f32 + } -> tensor + return %7 : tensor +}