diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -88,6 +88,10 @@
 /// This is effectively DCE for a linalg op.
 void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
 
+/// Patterns to promote inputs to outputs and remove unused inputs of
+/// `linalg.generic` ops.
+void populateEraseUnnecessaryInputsPatterns(RewritePatternSet &patterns);
+
 /// Function type to control generic op dimension collapsing. It is expected
 /// to return an array of `ReassociationIndices` representing dimensions that
 /// should be merged.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/EraseUnusedOperandsAndResults.cpp
@@ -56,7 +56,9 @@
 
 struct DeduplicateAndRemoveDeadOperandsAndResults
     : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
+  DeduplicateAndRemoveDeadOperandsAndResults(MLIRContext *ctx,
+                                             bool removeOutputs)
+      : OpRewritePattern<GenericOp>(ctx), removeOutputs(removeOutputs) {}
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
@@ -120,6 +122,9 @@
   }
 
 private:
+  /// If unset, outputs are not modified by this pattern.
+  bool removeOutputs;
+
   // Deduplicate input operands, and return the
   // - Mapping from operand position in the original op, to operand position in
   //   the canonicalized op.
@@ -176,9 +181,9 @@
 
     llvm::SmallDenseMap<unsigned, unsigned> origToNewPos;
     llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> dedupedOutpts;
-    // If the op doesnt have tensor semantics, keep all the outputs as
-    // preserved.
-    if (!genericOp.hasTensorSemantics()) {
+    // If the op doesn't have tensor semantics or outputs should not be removed,
+    // keep all the outputs as preserved.
+    if (!genericOp.hasTensorSemantics() || !removeOutputs) {
       for (const auto &en : llvm::enumerate(genericOp.getDpsInitOperands())) {
         origToNewPos[en.index()] = newOutputOperands.size();
         newOutputOperands.push_back(en.value()->get());
@@ -353,10 +358,69 @@
     return failure();
   }
 };
+
+/// Fold uses of duplicate inputs in the body of a linalg.generic. E.g.:
+/// ```
+/// linalg.generic ins(%a, %b, %a, %b) outs(%a)
+/// ^bb0(%in0, %in1, %in2, %in3, %out1)
+/// ```
+/// Assuming that all %a and %b have the same index map:
+/// * All uses of %in0 and %in2 are replaced with %out1
+/// * All uses of %in1 are replaced with %in3
+/// This pattern can enable additional canonicalizations: In the above example,
+/// %in0, %in1 and %in3 have no uses anymore and their corresponding operands
+/// can be folded away. This pattern does not modify uses of output block args.
+struct FoldDuplicateInputBbArgs : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    // Find replacement bbArgs for all input bbArg.
+    DenseMap<int, int> replacements;
+    for (int i = 0; i < genericOp.getNumDpsInputs(); ++i) {
+      // Skip bbArgs that have no uses.
+      if (genericOp.getBody()->getArgument(i).getUses().empty())
+        continue;
+      // Find replacement bbArg. This can be an input or an output bbArg.
+      for (int j = genericOp->getNumOperands() - 1; j > i; --j) {
+        if (genericOp->getOperand(i) == genericOp->getOperand(j) &&
+            genericOp.getIndexingMapsArray()[i] ==
+                genericOp.getIndexingMapsArray()[j]) {
+          replacements[i] = j;
+          break;
+        }
+      }
+    }
+
+    // Stop here if no replacements were found.
+    if (replacements.empty())
+      return failure();
+
+    // Rewrite the op.
+    rewriter.updateRootInPlace(genericOp, [&]() {
+      for (auto [before, after] : replacements) {
+        BlockArgument bbArg = genericOp.getBody()->getArgument(before);
+        BlockArgument replacement = genericOp.getBody()->getArgument(after);
+        rewriter.replaceAllUsesWith(bbArg, replacement);
+      }
+    });
+
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::linalg::populateEraseUnusedOperandsAndResultsPatterns(
     RewritePatternSet &patterns) {
-  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults,
-                  RemoveUnusedCycleInGenericOp>(patterns.getContext());
+  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
+      patterns.getContext(), /*removeOutputs=*/true);
+  patterns.insert<RemoveUnusedCycleInGenericOp>(patterns.getContext());
+}
+
+void mlir::linalg::populateEraseUnnecessaryInputsPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<DeduplicateAndRemoveDeadOperandsAndResults>(
+      patterns.getContext(), /*removeOutputs=*/false);
+  patterns.insert<FoldDuplicateInputBbArgs>(patterns.getContext());
 }
diff --git a/mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir b/mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir
--- a/mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir
+++ b/mlir/test/Dialect/Linalg/erase-unused-operands-and-results.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-erase-unused-operands-and-results | FileCheck %s
+// RUN: mlir-opt %s -split-input-file -test-linalg-transform-patterns=test-erase-unnecessary-inputs | FileCheck %s --check-prefix=CHECK-INPUT
 
 // CHECK-LABEL: func @remove_deadargs_generic_basic
 // CHECK-SAME: (%[[ARG0:.*]]: tensor<?xf32>) -> tensor<?xf32> {
@@ -493,3 +494,29 @@
 // CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
 // CHECK-SAME: outs(%[[ARG0]], %[[INIT]] :
 // CHECK: return %[[GENERIC]]#0
+
+
+// -----
+
+// CHECK-INPUT-LABEL: func @remove_unnecessary_input(
+// CHECK-INPUT-SAME: %[[a:.*]]: tensor<?xf32>, %[[b:.*]]: tensor<?xf32>
+#map = affine_map<(d0) -> (d0)>
+func.func @remove_unnecessary_input(%a: tensor<?xf32>, %b: tensor<?xf32>)
+    -> tensor<?xf32>
+{
+  // CHECK-INPUT: %[[result:.*]] = linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel"]}
+  // CHECK-INPUT-SAME: ins(%[[a]] : tensor<?xf32>) outs(%[[b]] : tensor<?xf32>) {
+  // CHECK-INPUT: ^bb0(%[[in:.*]]: f32, %[[out:.*]]: f32):
+  // CHECK-INPUT: %[[add:.*]] = arith.addf %[[in]], %[[out]]
+  // CHECK-INPUT: linalg.yield %[[add]]
+  // CHECK-INPUT: } -> tensor<?xf32>
+  // CHECK-INPUT: return %[[result]]
+  %0 = linalg.generic
+      {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]}
+      ins(%a, %b : tensor<?xf32>, tensor<?xf32>) outs(%b : tensor<?xf32>) {
+    ^bb0(%in: f32, %in_2: f32, %out: f32):
+      %16 = arith.addf %in, %in_2 : f32
+      linalg.yield %16 : f32
+  } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
--- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
+++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp
@@ -113,6 +113,10 @@
       *this, "test-erase-unused-operands-and-results",
       llvm::cl::desc("Test patterns to erase unused operands and results"),
       llvm::cl::init(false)};
+  Option<bool> testEraseUnnecessaryInputs{
+      *this, "test-erase-unnecessary-inputs",
+      llvm::cl::desc("Test patterns to erase unnecessary inputs"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -185,6 +189,12 @@
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
 }
 
+static void applyEraseUnnecessaryInputs(func::FuncOp funcOp) {
+  RewritePatternSet patterns(funcOp.getContext());
+  populateEraseUnnecessaryInputsPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnOperation() {
   if (testPatterns)
@@ -205,6 +215,8 @@
     return applySwapExtractSliceWithFillPattern(getOperation());
   if (testEraseUnusedOperandsAndResults)
     return applyEraseUnusedOperandsAndResultsPatterns(getOperation());
+  if (testEraseUnnecessaryInputs)
+    return applyEraseUnnecessaryInputs(getOperation());
 }
 
 namespace mlir {