diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -75,6 +75,12 @@
   let summary = "Fold TensorReshapeOps with generic/indexed generic ops by "
                 "linearization";
   let constructor = "mlir::createFoldReshapeOpsByLinearizationPass()";
+  let options = [
+    Option<"allowFoldingUnitDimReshapes", "allow-folding-unit-dim-reshapes",
+           "bool", /*default=*/"false",
+           "Allow fusing linalg.tensor_reshape ops that perform unit "
+           "dimension collapsing">
+  ];
   let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -320,27 +320,27 @@
 /// %0 = op ... : tensor<?x?x4x5xf32>
 /// with output index_map
 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
+template <typename TensorReshapeOp>
 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
-                                        ArrayRef<int64_t> sourceShape,
-                                        ArrayRef<AffineMap> reassociationMaps) {
+                                        TensorReshapeOp reshapeOp) {
+  constexpr bool isExpanding =
+      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
+  ArrayRef<int64_t> sourceShape =
+      (isExpanding ? reshapeOp.getResultType().getShape()
+                   : reshapeOp.getSrcType().getShape());
   SmallVector<AffineExpr> resultExprs;
-  resultExprs.reserve(reassociationMaps.size());
   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
   MLIRContext *context = sourceMap.getContext();
 
   // Compute the result exprs based on the reassociation maps.
-  for (AffineMap map : reassociationMaps) {
-    ArrayRef<AffineExpr> collapsedDims = map.getResults();
+  for (auto &indices : reshapeOp.getReassociationIndices()) {
     // Assume that they are in-order and contiguous (already checked in
     // verifier).
-    assert(!collapsedDims.empty());
-    unsigned startDim =
-        collapsedDims.front().cast<AffineDimExpr>().getPosition();
+    assert(!indices.empty());
     SmallVector<int64_t> sizes;
     SmallVector<AffineExpr> dimExprs;
-    for (auto en :
-         llvm::zip(sourceShape.slice(startDim, collapsedDims.size()),
-                   sourceExprs.slice(startDim, collapsedDims.size()))) {
+    for (auto en : llvm::zip(sourceShape.slice(indices[0], indices.size()),
+                             sourceExprs.slice(indices[0], indices.size()))) {
       if (std::get<0>(en) == 1)
         continue;
       sizes.push_back(std::get<0>(en));
@@ -359,7 +359,7 @@
 // divs in the indexing maps of the fused op which would make it non-invertible.
 static bool isTensorReshapeOpFoldableByLinearization(
     TensorExpandShapeOp expandOp, AffineMap useIndexMap, bool asProducer) {
-  if (!asProducer && expandOp.getResultType().hasStaticShape())
+  if (!asProducer)
     return false;
   return useIndexMap.isPermutation();
 }
@@ -368,23 +368,26 @@
 // consumer).
 static bool isTensorReshapeOpFoldableByLinearization(
     TensorCollapseShapeOp collapseOp, AffineMap useIndexMap, bool asProducer) {
-  if (asProducer && collapseOp.getSrcType().hasStaticShape())
+  if (asProducer)
     return false;
   return useIndexMap.isPermutation();
 }
 /// Check if the reshape operation is only expansion into/collapsing of
 /// unit-dimension.
-static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
-                                   ArrayRef<AffineMap> reassociation) {
-  for (auto &map : reassociation) {
+template <typename TensorReshapeOp>
+static bool isUnitDimExpansionOnly(TensorReshapeOp reshapeOp) {
+  constexpr bool isExpanding =
+      std::is_same<TensorReshapeOp, TensorExpandShapeOp>::value;
+  ArrayRef<int64_t> expandedShape =
+      (isExpanding ? reshapeOp.getResultType().getShape()
+                   : reshapeOp.getSrcType().getShape());
+  for (auto &indices : reshapeOp.getReassociationIndices()) {
     unsigned numUnitDims = 0;
-    for (AffineExpr expr : map.getResults()) {
-      unsigned position = expr.cast<AffineDimExpr>().getPosition();
+    for (int64_t position : indices)
       if (expandedShape[position] == 1)
         numUnitDims++;
-    }
-    if (numUnitDims != map.getNumResults() - 1)
+    if (numUnitDims != indices.size() - 1)
       return false;
   }
   return true;
 }
@@ -818,14 +821,10 @@
       if (!reshapeOp)
         continue;
 
-      RankedTensorType returnType = reshapeOp.getResultType();
-
       if (!isTensorReshapeOpFoldableByLinearization(
               reshapeOp, genericOp.getTiedIndexingMap(en.value()),
              /*asProducer =*/true) ||
-          (foldUnitDimReshapesOnly &&
-           !isUnitDimExpansionOnly(returnType.getShape(),
-                                   reshapeOp.getReassociationMaps())))
+          (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
         continue;
 
       // Compute the fused operands list,
@@ -842,8 +841,10 @@
       auto invMap = inversePermutation(fusedIndexMaps[en.index()]);
 
       // Compute the indexing map to use for the result of the producer.
-      AffineMap modifiedMap = linearizeCollapsedDims(
-          invMap, returnType.getShape(), reshapeOp.getReassociationMaps());
+      AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
+      // The modified map cannot have symbols.
+      if (modifiedMap.getNumSymbols())
+        return failure();
       for (AffineExpr expr : modifiedMap.getResults()) {
         if (!expr.isPureAffine())
           return failure();
@@ -1081,9 +1082,7 @@
     if (!isTensorReshapeOpFoldableByLinearization(
            reshapeOp, producer.getTiedIndexingMap(producer.getOutputOperand(0)),
            /*asProducer =*/false) ||
-        (foldUnitDimReshapesOnly &&
-         !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
-                                 reshapeOp.getReassociationMaps())))
+        (foldUnitDimReshapesOnly && !isUnitDimExpansionOnly(reshapeOp)))
       return failure();
     // The indexing_maps for the operands of the fused operation are same as
     // those for the operands of the producer.
@@ -1093,9 +1092,7 @@
     auto invMap = inversePermutation(
         producer.getTiedIndexingMap(producer.getOutputOperand(0)));
     // Compute the indexing map to use for the operand of the producer.
-    AffineMap modifiedMap =
-        linearizeCollapsedDims(invMap, reshapeOp.getSrcType().getShape(),
-                               reshapeOp.getReassociationMaps());
+    AffineMap modifiedMap = linearizeCollapsedDims(invMap, reshapeOp);
     for (AffineExpr expr : modifiedMap.getResults()) {
       if (!expr.isPureAffine()) {
         return rewriter.notifyMatchFailure(
@@ -1144,8 +1141,7 @@
     if (!producer || producer.getNumOutputs() != 1 ||
         !isFusableWithReshapeByDimExpansion(producer,
                                             producer.getOutputOperand(0)) ||
-        isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
-                               reshapeOp.getReassociationMaps()))
+        isUnitDimExpansionOnly(reshapeOp))
       return failure();
     Optional<SmallVector<Value>> replacementValues = fuseWithReshapeByExpansion(
         producer, reshapeOp, producer.getOutputOperand(0), rewriter);
@@ -1248,12 +1244,10 @@
 static bool skipUnitDimReshape(const OpResult &producer,
                                const OpOperand &consumer) {
   auto expandShapeOp = producer.getDefiningOp<TensorExpandShapeOp>();
   if (expandShapeOp)
-    return !isUnitDimExpansionOnly(expandShapeOp.getSrcType().getShape(),
-                                   expandShapeOp.getReassociationMaps());
+    return !isUnitDimExpansionOnly(expandShapeOp);
   auto collapseShapeOp = producer.getDefiningOp<TensorCollapseShapeOp>();
-  return !isUnitDimExpansionOnly(collapseShapeOp.getSrcType().getShape(),
-                                 collapseShapeOp.getReassociationMaps());
+  return !isUnitDimExpansionOnly(collapseShapeOp);
 }
 
 namespace {
@@ -1312,6 +1306,9 @@
     Operation *op = getOperation();
     RewritePatternSet patterns(op->getContext());
    populateFoldReshapeOpsByLinearizationPatterns(patterns);
+    if (allowFoldingUnitDimReshapes) {
+      populateFoldUnitDimsReshapeOpsByLinearizationPatterns(patterns);
+    }
     (void)applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
   }
 };
diff --git a/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir b/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/reshape_linearization_fusion_with_unit_dims.mlir
@@ -0,0 +1,52 @@
+// RUN: mlir-opt -linalg-fold-reshape-ops-by-linearization=allow-folding-unit-dim-reshapes -split-input-file %s | FileCheck %s
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @do_not_fold1(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?x1xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = memref.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = memref.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = linalg.init_tensor [%0, %1] : tensor<?x?xf32>
+  %3 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+      %4 = addf %arg2, %arg3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<?x?xf32>
+  %4 = linalg.tensor_expand_shape %3 [[0], [1, 2]] : tensor<?x?xf32> into tensor<?x?x1xf32>
+  return %4 : tensor<?x?x1xf32>
+}
+// CHECK-LABEL: func @do_not_fold1
+// CHECK: %[[VAL:.+]] = linalg.generic
+// CHECK: linalg.tensor_expand_shape %[[VAL]]
+
+// -----
+
+#map = affine_map<(d0, d1) -> (d0, d1)>
+func @do_not_fold2(%arg0 : tensor<?x?x1xf32>, %arg1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+{
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %0 = linalg.tensor_collapse_shape %arg0 [[0], [1, 2]] : tensor<?x?x1xf32> into tensor<?x?xf32>
+  %1 = memref.dim %arg1, %c0 : tensor<?x?xf32>
+  %2 = memref.dim %arg1, %c1 : tensor<?x?xf32>
+  %3 = linalg.init_tensor [%1, %2] : tensor<?x?xf32>
+  %4 = linalg.generic {
+    indexing_maps = [#map, #map, #map],
+    iterator_types = ["parallel", "parallel"]}
+    ins(%0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%3 : tensor<?x?xf32>) {
+    ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+      %4 = addf %arg2, %arg3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<?x?xf32>
+  return %4 : tensor<?x?xf32>
+}
+// CHECK-LABEL: func @do_not_fold2
+// CHECK: %[[VAL:.+]] = linalg.tensor_collapse_shape
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[VAL]], %{{.+}} : tensor<?x?xf32>, tensor<?x?xf32>)
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1497,9 +1497,27 @@
 
 cc_library(
     name = "SparseTensor",
-    srcs = glob(["lib/Dialect/SparseTensor/IR/*.cpp"]),
+    srcs = glob([
+        "lib/Dialect/SparseTensor/IR/*.cpp",
+    ]),
     hdrs = ["include/mlir/Dialect/SparseTensor/IR/SparseTensor.h"],
     includes = ["include"],
+    deps = [
+        ":IR",
+        ":SideEffectInterfaces",
+        ":SparseTensorAttrDefsIncGen",
+        ":SparseTensorOpsIncGen",
+        ":SparseTensorUtils",
+        ":StandardOps",
+        "//llvm:Support",
+    ],
+)
+
+cc_library(
+    name = "SparseTensorUtils",
+    srcs = glob(["lib/Dialect/SparseTensor/Utils/*.cpp"]),
+    hdrs = glob(["include/mlir/Dialect/SparseTensor/Utils/*.h"]),
+    includes = ["include"],
     deps = [
         ":IR",
         ":SideEffectInterfaces",
@@ -1535,17 +1553,6 @@
     ],
 )
 
-cc_library(
-    name = "SparseTensorUtils",
-    srcs = glob(["lib/Dialect/SparseTensor/Utils/*.cpp"]),
-    hdrs = glob(["include/mlir/Dialect/SparseTensor/Utils/*.h"]),
-    includes = ["include"],
-    deps = [
-        ":IR",
-        "//llvm:Support",
-    ],
-)
-
 td_library(
     name = "StdOpsTdFiles",
     srcs = [