Index: mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h =================================================================== --- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -76,6 +76,18 @@ RewritePatternSet &patterns, const ControlFusionFn &controlElementwiseOpFusion); +/// Function type to control generic op dimension collapsing. It is expected +/// to return an array of `ReassociationIndices` representing dimensions that +/// should be merged. +using ControlCollapseFn = + std::function<SmallVector<ReassociationIndices>(linalg::GenericOp)>; + +/// Pattern to collapse dimensions in a linalg.generic op. This will collapse +/// tensor operands when needed and expand back the result tensors. +void populateCollapseDimensions( + RewritePatternSet &patterns, + const ControlCollapseFn &controlCollapseDimensions); + /// Patterns to fold an expanding (collapsing) tensor_reshape operation with its /// producer (consumer) generic operation by expanding the dimensionality of the /// loop in the generic op. Index: mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1367,7 +1367,7 @@ /// Implementation of fusion with reshape operation by collapsing dimensions. static FailureOr<SmallVector<Value>> collapseGenericOpIterationDims( GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims, - OpOperand *fusableOpOperand, PatternRewriter &rewriter) { + PatternRewriter &rewriter) { // Bail on trivial no-op cases. 
if (genericOp.getNumLoops() <= 1 || foldedIterationDims.empty() || llvm::all_of(foldedIterationDims, [](ReassociationIndicesRef foldedDims) { @@ -1510,7 +1510,7 @@ Optional<SmallVector<Value>> replacements = collapseGenericOpIterationDims(genericOp, collapsableIterationDims, - opOperand, rewriter); + rewriter); if (!replacements) { return rewriter.notifyMatchFailure( genericOp, "failed to do the fusion by collapsing transformation"); @@ -1525,6 +1525,37 @@ private: ControlFusionFn controlFoldingReshapes; }; + +/// Pattern to collapse dimensions. +class CollapseLinalgDimensions : public OpRewritePattern<GenericOp> { +public: + CollapseLinalgDimensions(MLIRContext *context, + ControlCollapseFn collapseDimensions, + PatternBenefit benefit = 1) + : OpRewritePattern<GenericOp>(context, benefit), + controlCollapseDimension(std::move(collapseDimensions)) {} + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + SmallVector<ReassociationIndices> collapsableIterationDims = + controlCollapseDimension(genericOp); + if (collapsableIterationDims.empty()) + return failure(); + + Optional<SmallVector<Value>> replacements = collapseGenericOpIterationDims( + genericOp, collapsableIterationDims, rewriter); + if (!replacements) { + return rewriter.notifyMatchFailure(genericOp, + "failed to collapse dimensions"); + } + rewriter.replaceOp(genericOp, *replacements); + return success(); + } + +private: + ControlCollapseFn controlCollapseDimension; +}; + } // namespace //===---------------------------------------------------------------------===// @@ -1743,6 +1774,13 @@ RemoveOutsDependency>(context); } +void mlir::linalg::populateCollapseDimensions( + RewritePatternSet &patterns, + const ControlCollapseFn &controlCollapseDimensions) { + patterns.add<CollapseLinalgDimensions>(patterns.getContext(), + controlCollapseDimensions); +} + //===---------------------------------------------------------------------===// // Passes //===---------------------------------------------------------------------===// Index: 
mlir/test/Dialect/Linalg/collapse-dim.mlir =================================================================== --- /dev/null +++ mlir/test/Dialect/Linalg/collapse-dim.mlir @@ -0,0 +1,55 @@ +// RUN: mlir-opt %s -test-linalg-elementwise-fusion-patterns=collapse-dimensions-control=2,3 -split-input-file | FileCheck %s + +func.func @collapse_reduction( + %arg0: tensor<2x32x10x4096xf32>, %arg1: tensor<2x32xf32>) -> tensor<2x32xf32> { + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1)>], + iterator_types = ["parallel", "parallel", "reduction", "reduction"]} + ins(%arg0 : tensor<2x32x10x4096xf32>) outs(%arg1 : tensor<2x32xf32>) { + ^bb0(%arg3: f32, %arg4: f32): + %1 = arith.addf %arg3, %arg4 : f32 + linalg.yield %1 : f32 + } -> tensor<2x32xf32> + return %0 : tensor<2x32xf32> +} + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> + +// CHECK-LABEL: func @collapse_reduction +// CHECK: %[[T:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<2x32x10x4096xf32> into tensor<2x32x40960xf32> +// CHECK: linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]} +// CHECK-SAME: ins(%[[T]] : tensor<2x32x40960xf32>) outs(%{{.*}} : tensor<2x32xf32>) { +// CHECK: } -> tensor<2x32xf32> + +// ----- + +func.func @collapse_parallel( + %arg0: tensor<32x2x10x4096xf32>, %arg1: tensor<2x32x10x4096xf32>) -> tensor<2x32x10x4096xf32> { + %0 = linalg.generic { + indexing_maps = [ + affine_map<(d0, d1, d2, d3) -> (d1, d0, d2, d3)>, + affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%arg0 : tensor<32x2x10x4096xf32>) outs(%arg1 : tensor<2x32x10x4096xf32>) { + ^bb0(%arg3: f32, %arg4: f32): + %1 = arith.addf %arg3, %arg4 : f32 + linalg.yield %1 : f32 + } -> 
tensor<2x32x10x4096xf32> + return %0 : tensor<2x32x10x4096xf32> +} + +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> + +// CHECK-LABEL: func @collapse_parallel +// CHECK-DAG: %[[S:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<32x2x10x4096xf32> into tensor<32x2x40960xf32> +// CHECK-DAG: %[[D:.*]] = tensor.collapse_shape %{{.*}} {{\[}}[0], [1], [2, 3]] : tensor<2x32x10x4096xf32> into tensor<2x32x40960xf32> +// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP1]]], +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]} +// CHECK-SAME: ins(%[[S]] : tensor<32x2x40960xf32>) outs(%[[D]] : tensor<2x32x40960xf32>) { +// CHECK: } -> tensor<2x32x40960xf32> +// CHECK: tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32> Index: mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp =================================================================== --- mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp +++ mlir/test/lib/Dialect/Linalg/TestLinalgElementwiseFusion.cpp @@ -99,6 +99,9 @@ "fusion patterns that " "collapse the iteration space of the consumer"), llvm::cl::init(false)}; + ListOption<int64_t> collapseDimensions{ + *this, "collapse-dimensions-control", + llvm::cl::desc("Test controlling dimension collapse pattern")}; void runOnOperation() override { MLIRContext *context = &this->getContext(); @@ -179,6 +182,19 @@ linalg::populateFoldReshapeOpsByCollapsingPatterns(patterns, controlFn); (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); } + + if (!collapseDimensions.empty()) { + SmallVector<int64_t, 2> dims(collapseDimensions.begin(), + collapseDimensions.end()); + linalg::ControlCollapseFn collapseFn = [&dims](linalg::GenericOp op) { + SmallVector<ReassociationIndices> reassociations; + reassociations.emplace_back(dims); + return reassociations; + }; + 
RewritePatternSet patterns(context); + linalg::populateCollapseDimensions(patterns, collapseFn); + (void)applyPatternsAndFoldGreedily(funcOp.getBody(), std::move(patterns)); + } } };