diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -901,8 +901,21 @@ const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc = false); -/// Collapses dimensions of linalg.generic operation. It also collapses inputs -/// before the op and expands outputs after the op. +/// Return `true` if a given sequence of dimensions is contiguous in the +/// range of the specified indexing map. +bool isDimSequencePreserved(AffineMap map, ReassociationIndicesRef dimSequence); +/// Return `true` if all sequences of dimensions specified in `dimSequences` are +/// contiguous in all the ranges of the `maps`. +bool areDimSequencesPreserved(ArrayRef<AffineMap> maps, + ArrayRef<ReassociationIndices> dimSequences); + +/// Collapses dimensions of linalg.generic operation. A precondition to +/// calling this method is that for each list in `foldedIterationDims`, the +/// sequence of dimensions is contiguous in the ranges of all `indexing_maps` of +/// the `genericOp`. This can be checked using the `areDimSequencesPreserved` method. +/// When valid, the method also collapses the operands of the op. Returns +/// replacement values of the results of the original `genericOp` by inserting +/// reshapes to get back values of compatible types. FailureOr<SmallVector<Value>> collapseGenericOpIterationDims( GenericOp genericOp, ArrayRef<ReassociationIndices> foldedIterationDims, RewriterBase &rewriter); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1004,8 +1004,8 @@ /// For a given `dimSequence`, check if the sequence is conserved in the /// `indexingMap`. `indexingMap` is expected to be a projected permutation. 
/// Non-existence of the sequence returns true as well. -static bool isDimSequencePreserved(AffineMap indexingMap, - ReassociationIndicesRef dimSequence) { +bool mlir::linalg::isDimSequencePreserved(AffineMap indexingMap, + ReassociationIndicesRef dimSequence) { assert(!dimSequence.empty() && "expected non-empty list for dimension sequence"); assert(indexingMap.isProjectedPermutation() && @@ -1045,6 +1045,15 @@ return true; } +bool mlir::linalg::areDimSequencesPreserved( + ArrayRef<AffineMap> maps, ArrayRef<ReassociationIndices> dimSequences) { + return llvm::all_of(maps, [&](AffineMap map) { + return llvm::all_of(dimSequences, [&](ReassociationIndicesRef dimSequence) { + return isDimSequencePreserved(map, dimSequence); + }); + }); +} + // Return the list of dimensions of the iteration domain that can be // collapsed to allow for fusion with the a producer that is an expand_shape // operation. If all dimensions created by expansion can be collapsed in the @@ -1592,6 +1601,13 @@ if (collapsableIterationDims.empty()) return failure(); + // Check if the specified list of dimensions to collapse is a valid list. 
+ if (!areDimSequencesPreserved(genericOp.getIndexingMapsArray(), + collapsableIterationDims)) { + return rewriter.notifyMatchFailure( + genericOp, "specified dimensions cannot be collapsed"); + } + std::optional<SmallVector<Value>> replacements = collapseGenericOpIterationDims(genericOp, collapsableIterationDims, rewriter); diff --git a/mlir/test/Dialect/Linalg/collapse-dim.mlir b/mlir/test/Dialect/Linalg/collapse-dim.mlir --- a/mlir/test/Dialect/Linalg/collapse-dim.mlir +++ b/mlir/test/Dialect/Linalg/collapse-dim.mlir @@ -53,3 +53,20 @@ // CHECK-SAME: ins(%[[S]] : tensor<32x2x40960xf32>) outs(%[[D]] : tensor<2x32x40960xf32>) { // CHECK: } -> tensor<2x32x40960xf32> // CHECK: tensor.expand_shape %[[R]] {{\[}}[0], [1], [2, 3]] : tensor<2x32x40960xf32> into tensor<2x32x10x4096xf32> + +// ----- + +#map = affine_map<(d0, d1, d2, d3) -> (d3, d0, d1, d2)> +#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +func.func @uncollapsable(%arg0 : tensor<41x3x1x57xf32>, %arg1 : tensor<3x1x57x41xf32>) -> tensor<3x1x57x41xf32> { + %0 = linalg.generic { + indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} + ins(%arg0 : tensor<41x3x1x57xf32>) outs(%arg1 : tensor<3x1x57x41xf32>) { + ^bb0(%in: f32, %out: f32): + linalg.yield %in : f32 + } -> tensor<3x1x57x41xf32> + return %0 : tensor<3x1x57x41xf32> +} +// CHECK-LABEL: func @uncollapsable( +// CHECK: linalg.generic +// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]