diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -72,6 +72,15 @@
 void populateFoldReshapeOpsByLinearizationPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns);
 
+/// Patterns to fold a collapsing (expanding) tensor_reshape operation with its
+/// producer (consumer) generic/indexed_generic operation by linearizing the
+/// indexing map used to access the source (target) of the reshape operation in
+/// the generic/indexed_generic operation. The patterns are applied only when
+/// the tensor reshape involved is collapsing (introducing) unit-extent
+/// dimensions.
+void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns);
+
 /// Patterns for fusing linalg operation on tensors.
 void populateLinalgTensorOpsFusionPatterns(MLIRContext *context,
                                            OwningRewritePatternList &patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -497,6 +497,7 @@
           ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
   patterns.insert<FoldReshapeOpWithUnitExtent>(context);
+  populateFoldUnitDimsReshapeOpsByLinearizationPatterns(context, patterns);
 }
 
 namespace {
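To illustrate what wiring the new entry point into the unit-extent-dims pass achieves, here is a sketch of the intended rewrite. The shapes are borrowed from the new drop-unit-extent-dims.mlir test further down; the `@before`/`@after` function names and the exact post-fold maps are illustrative only (the test merely checks that the reshape disappears into a single linalg.generic):

```mlir
// Before: the generic writes tensor<1x2x5xf32>, which is immediately
// collapsed to tensor<2x5xf32> by a reshape that only drops a unit dim.
func @before(%arg0 : tensor<5xf32>) -> tensor<2x5xf32> {
  %0 = linalg.init_tensor [1, 2, 5] : tensor<1x2x5xf32>
  %1 = linalg.generic
      {indexing_maps = [affine_map<(d0, d1, d2) -> (d2)>,
                        affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
       iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<1x2x5xf32>) {
    ^bb0(%arg1: f32, %arg2: f32):
      linalg.yield %arg1 : f32
  } -> tensor<1x2x5xf32>
  %2 = linalg.tensor_reshape %1
      [affine_map<(d0, d1, d2) -> (d0, d1)>,
       affine_map<(d0, d1, d2) -> (d2)>]
      : tensor<1x2x5xf32> into tensor<2x5xf32>
  return %2 : tensor<2x5xf32>
}

// After: the reshape is folded by linearizing the output indexing map, and
// the now-trivial unit loop is cleaned up by the rest of the pass.
func @after(%arg0 : tensor<5xf32>) -> tensor<2x5xf32> {
  %0 = linalg.init_tensor [2, 5] : tensor<2x5xf32>
  %1 = linalg.generic
      {indexing_maps = [affine_map<(d0, d1) -> (d1)>,
                        affine_map<(d0, d1) -> (d0, d1)>],
       iterator_types = ["parallel", "parallel"]}
      ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<2x5xf32>) {
    ^bb0(%arg1: f32, %arg2: f32):
      linalg.yield %arg1 : f32
  } -> tensor<2x5xf32>
  return %1 : tensor<2x5xf32>
}
```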
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -302,9 +302,18 @@
     assert(!collapsedDims.empty());
     unsigned startDim =
         collapsedDims.front().cast<AffineDimExpr>().getPosition();
-    AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
-        sourceShape.slice(startDim, collapsedDims.size()),
-        sourceExprs.slice(startDim, collapsedDims.size()), context);
+    SmallVector<int64_t, 4> sizes;
+    SmallVector<AffineExpr, 4> dimExprs;
+    for (auto en :
+         llvm::zip(sourceShape.slice(startDim, collapsedDims.size()),
+                   sourceExprs.slice(startDim, collapsedDims.size()))) {
+      if (std::get<0>(en) == 1)
+        continue;
+      sizes.push_back(std::get<0>(en));
+      dimExprs.push_back(std::get<1>(en));
+    }
+    AffineExpr linearizedExpr =
+        makeCanonicalStridedLayoutExpr(sizes, dimExprs, context);
     resultExprs.push_back(linearizedExpr);
   }
   return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
@@ -349,6 +358,23 @@
   return nullptr;
 }
 
+/// Check if the reshape operation is only expansion into/collapsing of
+/// unit-dimensions.
+static bool isUnitDimExpansionOnly(ArrayRef<int64_t> expandedShape,
+                                   ArrayRef<AffineMap> reassociation) {
+  for (auto &map : reassociation) {
+    unsigned numUnitDims = 0;
+    for (AffineExpr expr : map.getResults()) {
+      unsigned position = expr.cast<AffineDimExpr>().getPosition();
+      if (expandedShape[position] == 1)
+        numUnitDims++;
+    }
+    if (numUnitDims != map.getNumResults() - 1)
+      return false;
+  }
+  return true;
+}
+
 /// Conditions for folding a generic/indexed-generic operation with a reshape op
 /// by expanding the iteration space dimensionality for tensor operations. These
 /// are preconditions assumed by `foldReshapeByDimExpansion` which implements
@@ -776,7 +802,7 @@
 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
 ///        -> tensor<?x?x4x?xf32>
-template <typename LinalgOpTy>
+template <typename LinalgOpTy, bool foldUnitDimReshapesOnly>
 struct FoldProducerReshapeOpByLinearization
     : public OpRewritePattern<LinalgOpTy> {
   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
@@ -792,7 +818,10 @@
       if (!reshapeOp ||
           !isTensorReshapeOpFoldableByLinearization(
               reshapeOp, linalgOp.getInputIndexingMap(operand.index()),
-              /*asProducer =*/true))
+              /*asProducer =*/true) ||
+          (foldUnitDimReshapesOnly &&
+           !isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
+                                   reshapeOp.getReassociationMaps())))
         continue;
 
       // Compute the fused operands list,
@@ -858,7 +887,9 @@
       // - All constraints of fusing with reshape by expansion are met.
       if (reshapeOp.getSrcType().getRank() <
               reshapeOp.getResultType().getRank() ||
-          !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()))
+          !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()) ||
+          isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                                 reshapeOp.getReassociationMaps()))
         continue;
 
       Optional<SmallVector<Value, 1>> replacementValues =
@@ -877,6 +908,7 @@
 
 /// Pattern to fold tensor_reshape op with its producer. The corresponding index
 /// map in the consumer needs to be modified to linearize the folded dimension.
+template <bool foldUnitDimReshapesOnly>
 struct FoldConsumerReshapeOpByLinearization
     : public OpRewritePattern<TensorReshapeOp> {
   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
@@ -888,7 +920,11 @@
         !isa<GenericOp, IndexedGenericOp>(producer.getOperation()) ||
         !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 ||
         !isTensorReshapeOpFoldableByLinearization(
-            reshapeOp, producer.getOutputIndexingMap(0), /*asProducer =*/false))
+            reshapeOp, producer.getOutputIndexingMap(0),
+            /*asProducer =*/false) ||
+        (foldUnitDimReshapesOnly &&
+         !isUnitDimExpansionOnly(reshapeOp.getSrcType().getShape(),
+                                 reshapeOp.getReassociationMaps())))
       return failure();
     // The indexing_maps for the operands of the fused operation are same as
     // those for the operands of the producer.
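The effect of the `makeCanonicalStridedLayoutExpr` change above is easiest to see on a single collapsed group. A small sketch, with extents (1, 2, 5) assumed for illustration: previously every dimension of the group contributed a stride to the linearized expression; now unit extents are filtered out before the strided expression is built.

```mlir
// Collapsing the group (d0, d1, d2) with source extents (1, 2, 5); the
// canonical strides over all three extents are (10, 5, 1).
#linearized_before = affine_map<(d0, d1, d2) -> (d0 * 10 + d1 * 5 + d2)>
// With the unit extent filtered out, only sizes (2, 5) remain, the strides
// are (5, 1), and d0 no longer appears in the linearized expression.
#linearized_after = affine_map<(d0, d1, d2) -> (d1 * 5 + d2)>
```

Both maps agree on all valid indices (d0 can only be 0 when its extent is 1), but the filtered form avoids materializing unit dimensions in the fused indexing map.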
@@ -949,7 +985,10 @@
       return failure();
     LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
     if (!producer || producer.getNumOutputs() != 1 ||
-        !isFusableWithReshapeByDimExpansion(producer, producer.getNumInputs()))
+        !isFusableWithReshapeByDimExpansion(producer,
+                                            producer.getNumInputs()) ||
+        isUnitDimExpansionOnly(reshapeOp.getResultType().getShape(),
+                               reshapeOp.getReassociationMaps()))
       return failure();
     Optional<SmallVector<Value, 1>> replacementValues =
         fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(),
@@ -1098,9 +1137,16 @@
 
 void mlir::populateFoldReshapeOpsByLinearizationPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp>,
-                  FoldProducerReshapeOpByLinearization<IndexedGenericOp>,
-                  FoldConsumerReshapeOpByLinearization>(context);
+  patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, false>,
+                  FoldProducerReshapeOpByLinearization<IndexedGenericOp, false>,
+                  FoldConsumerReshapeOpByLinearization<false>>(context);
+}
+
+void mlir::populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
+    MLIRContext *context, OwningRewritePatternList &patterns) {
+  patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp, true>,
+                  FoldProducerReshapeOpByLinearization<IndexedGenericOp, true>,
+                  FoldConsumerReshapeOpByLinearization<true>>(context);
 }
 
 void mlir::populateFoldReshapeOpsByExpansionPatterns(
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -331,3 +331,26 @@
     ] : tensor<2x1x1xf32> into tensor<2x1xf32>
   return %1 : tensor<2x1xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d1)>
+#map4 = affine_map<(d0, d1, d2) -> (d2)>
+func @fold_unit_dim_tensor_reshape_op(%arg0 : tensor<5xf32>) -> tensor<2x5xf32>
+{
+  %1 = linalg.init_tensor [1, 2, 5] : tensor<1x2x5xf32>
+  %2 = linalg.generic {indexing_maps = [#map1, #map0],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : tensor<5xf32>) outs(%1 : tensor<1x2x5xf32>) {
+    ^bb0(%arg1: f32, %arg2: f32):  // no predecessors
+      linalg.yield %arg1 : f32
+    } -> tensor<1x2x5xf32>
+  %3 = linalg.tensor_reshape %2 [#map3, #map4]
+    : tensor<1x2x5xf32> into tensor<2x5xf32>
+  return %3 : tensor<2x5xf32>
+}
+// CHECK-LABEL: func @fold_unit_dim_tensor_reshape_op
+// CHECK:         %[[RESULT:.+]] = linalg.generic
+// CHECK:         return %[[RESULT]]
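On the fusion side, the expansion-based patterns now explicitly bail out when a reshape only introduces or removes unit extents, leaving those cases to the linearization patterns above. A hedged sketch of the classification, with shapes assumed for illustration (the new reshape_fusion.mlir tests below exercise the "skipped" behavior):

```mlir
func @classify(%a : tensor<6x5xf32>, %b : tensor<5x5xf32>)
    -> (tensor<2x3x5xf32>, tensor<5x1x5xf32>) {
  // Still eligible for fusion by expansion: the group (d0, d1) has extents
  // (2, 3) with no unit dims, so isUnitDimExpansionOnly is false.
  %0 = linalg.tensor_reshape %a
      [affine_map<(d0, d1, d2) -> (d0, d1)>,
       affine_map<(d0, d1, d2) -> (d2)>]
      : tensor<6x5xf32> into tensor<2x3x5xf32>
  // Skipped by expansion after this change: each group contains exactly
  // getNumResults() - 1 unit extents ((5, 1) and (5)), i.e. the reshape
  // only inserts a unit dimension.
  %1 = linalg.tensor_reshape %b
      [affine_map<(d0, d1, d2) -> (d0, d1)>,
       affine_map<(d0, d1, d2) -> (d2)>]
      : tensor<5x5xf32> into tensor<5x1x5xf32>
  return %0, %1 : tensor<2x3x5xf32>, tensor<5x1x5xf32>
}
```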
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -188,42 +188,6 @@
 
 // -----
 
-func @scalar_reshape(
-  %arg0 : tensor<1x10xf32>, %arg1 : tensor<1xf32>) -> tensor<1x10xf32>
-{
-  %0 = linalg.tensor_reshape %arg1 [] : tensor<1xf32> into tensor<f32>
-  %1 = linalg.init_tensor [10] : tensor<10xf32>
-  %2 = linalg.generic
-    {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>],
-     iterator_types = ["parallel"]}
-    ins(%0 : tensor<f32>)
-    outs(%1 : tensor<10xf32>) {
-    ^bb0(%arg2: f32, %s: f32):  // no predecessors
-      linalg.yield %arg2 : f32
-    } -> tensor<10xf32>
-  %3 = linalg.tensor_reshape %2 [affine_map<(d0, d1) -> (d0, d1)>]
-    : tensor<10xf32> into tensor<1x10xf32>
-  return %3 : tensor<1x10xf32>
-}
-
-// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0, d1)>
-// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> ()>
-// CHECK:     func @scalar_reshape
-// CHECK-SAME:  %[[ARG0:[a-zA-Z0-9_]+]]: tensor<1x10xf32>
-// CHECK-SAME:  %[[ARG1:[a-zA-Z0-9_]+]]: tensor<1xf32>
-// CHECK:       %[[T0:.+]] = linalg.tensor_reshape %[[ARG1]] []
-// CHECK-SAME:    tensor<1xf32> into tensor<f32>
-// CHECK:       %[[T1:.+]] = linalg.init_tensor [10]
-// CHECK:       %[[T2:.+]] = linalg.tensor_reshape %[[T1]] [#[[MAP0]]]
-// CHECK:       %[[T3:.+]] = linalg.generic
-// CHECK-SAME:    indexing_maps = [#[[MAP1]], #[[MAP0]]]
-// CHECK-SAME:    iterator_types = ["parallel", "parallel"]
-// CHECK-SAME:    ins(%[[T0]] : tensor<f32>)
-// CHECK-SAME:    outs(%[[T2]] : tensor<1x10xf32>)
-// CHECK:       return %[[T3]] : tensor<1x10xf32>
-
-// -----
-
 #map0 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
 #map1 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
 func @indexed_generic_op_reshape_producer_fusion(%arg0 : tensor<?x?x4x?xi32>,
@@ -336,7 +300,7 @@
     %5 = addi %3, %4 : i32
     %6 = index_cast %arg2 : index to i32
     %7 = addi %5, %6 : i32
-    linalg.yield %7 : i32 
+    linalg.yield %7 : i32
   } -> tensor<6x4x210xi32>
   %d = linalg.tensor_reshape %c
     [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
@@ -493,3 +457,77 @@
 // CHECK-SAME:     ins(%[[T0]], %[[T1]] : tensor, tensor)
 // CHECK-SAME:     outs(%[[T2]] : tensor)
 // CHECK:        return %[[T3]] : tensor
+
+// -----
+
+func @unit_dim_reshape_expansion(%arg0 : tensor<1x5xf32>) -> tensor<5x5xf32> {
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1) -> (d0, d1)>] : tensor<1x5xf32> into tensor<5xf32>
+  %1 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
+  %2 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%0 : tensor<5xf32>) outs(%1 : tensor<5x5xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<5x5xf32>
+  return %2 : tensor<5x5xf32>
+}
+// CHECK:     func @unit_dim_reshape_expansion
+// CHECK-DAG:   linalg.tensor_reshape
+// CHECK-DAG:   linalg.init_tensor
+// CHECK:       linalg.generic
+
+// -----
+
+func @unit_dim_reshape_collapse(%arg0 : tensor<5xf32>) -> tensor<5x1x5xf32> {
+  %0 = linalg.init_tensor [5, 5] : tensor<5x5xf32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1) -> (d0)>,
+                      affine_map<(d0, d1) -> (d0, d1)>],
+     iterator_types = ["parallel", "parallel"]}
+    ins(%arg0 : tensor<5xf32>) outs(%0 : tensor<5x5xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+      linalg.yield %arg2 : f32
+    } -> tensor<5x5xf32>
+  %2 = linalg.tensor_reshape %1
+    [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>]
+    : tensor<5x5xf32> into tensor<5x1x5xf32>
+  return %2 : tensor<5x1x5xf32>
+}
+// CHECK: func @unit_dim_reshape_collapse
+// CHECK:   linalg.init_tensor
+// CHECK:   linalg.generic
+// CHECK:   linalg.tensor_reshape
+
+// -----
+
+func @unit_dim_reshape_expansion_full
+  (%arg0 : tensor<1x?x1x2x1x4xf32>, %arg1 : tensor<?x2x4xf32>)
+  -> tensor<?x2x4xf32> {
+  %c1 = constant 1 : index
+  %0 = linalg.tensor_reshape %arg0
+    [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>,
+     affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>]
+    : tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>
+  %1 = dim %arg0, %c1 : tensor<1x?x1x2x1x4xf32>
+  %2 = linalg.init_tensor [%1, 2, 4] : tensor<?x2x4xf32>
+  %3 = linalg.generic
+    {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                      affine_map<(d0, d1, d2) -> (d0, d1, d2)>],
+     iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%0, %arg1 : tensor<?x2x4xf32>, tensor<?x2x4xf32>)
+    outs(%2 : tensor<?x2x4xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32, %arg4: f32):  // no predecessors
+      %4 = mulf %arg2, %arg3 : f32
+      linalg.yield %4 : f32
+    } -> tensor<?x2x4xf32>
+  return %3 : tensor<?x2x4xf32>
+}
+// CHECK:     func @unit_dim_reshape_expansion_full
+// CHECK-DAG:   linalg.tensor_reshape
+// CHECK-DAG:   linalg.init_tensor
+// CHECK:       linalg.generic
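As a worked reading of the new gating on the dynamic test above (this is annotation, not part of the patch): a dynamic extent is never counted as a unit dim, yet the reshape still qualifies as unit-dim-only because every reassociation group contains exactly `getNumResults() - 1` unit extents. That is why the test expects the tensor_reshape to survive fusion by expansion.

```mlir
// tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>, group by group:
//   (d0, d1, d2) over extents (1, ?, 1): 2 unit dims == 3 - 1 results
//   (d3, d4)     over extents (2, 1)   : 1 unit dim  == 2 - 1 results
//   (d5)         over extents (4)      : 0 unit dims == 1 - 1 results
// All groups pass, so fusion by expansion leaves this reshape in place.
func @annotated(%arg0 : tensor<1x?x1x2x1x4xf32>) -> tensor<?x2x4xf32> {
  %0 = linalg.tensor_reshape %arg0
      [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>,
       affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4)>,
       affine_map<(d0, d1, d2, d3, d4, d5) -> (d5)>]
      : tensor<1x?x1x2x1x4xf32> into tensor<?x2x4xf32>
  return %0 : tensor<?x2x4xf32>
}
```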