diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -524,6 +524,7 @@
   LogicalResult compute(LinalgOp linalgOp, OpOperand *fusableOpOperand,
                         ArrayRef<AffineMap> reassociationMaps,
                         ArrayRef<int64_t> expandedShape,
+                        ArrayRef<int64_t> collapsedShape,
                         PatternRewriter &rewriter);
   unsigned getOrigOpNumDims() const { return reassociation.size(); }
   unsigned getExpandedOpNumDims() const { return expandedOpNumDims; }
@@ -533,6 +534,7 @@
   ArrayRef<int64_t> getExpandedShapeOfDim(unsigned i) const {
     return expandedShapeMap[i];
   }
+  ArrayRef<int64_t> getOriginalShape() const { return originalLoopExtent; }

 private:
   /// Reassociation from the dimensions in the original operation to the
@@ -541,6 +543,8 @@
   /// Mapping from extent of loops in the original operation, to the extent of
   /// loops in the expanded operation.
   SmallVector<SmallVector<int64_t>> expandedShapeMap;
+  /// Extent of the loops in the original operation.
+  SmallVector<int64_t> originalLoopExtent;
   unsigned expandedOpNumDims;
 };
 } // namespace
@@ -549,6 +553,7 @@
                                      OpOperand *fusableOpOperand,
                                      ArrayRef<AffineMap> reassociationMaps,
                                      ArrayRef<int64_t> expandedShape,
+                                     ArrayRef<int64_t> collapsedShape,
                                      PatternRewriter &rewriter) {
   if (reassociationMaps.empty())
     return failure();
@@ -558,6 +563,8 @@
       linalgOp.getStaticLoopRanges();
   if (!originalLoopRange)
     return rewriter.notifyMatchFailure(linalgOp, "unable to find loop range");
+  originalLoopExtent.assign(originalLoopRange->begin(),
+                            originalLoopRange->end());

   reassociation.clear();
   expandedShapeMap.clear();
@@ -576,7 +583,7 @@
   // The remaining dimensions remain the same.
   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
     if (expandedShapeMap[i].empty())
-      expandedShapeMap[i] = {(*originalLoopRange)[i]};
+      expandedShapeMap[i] = {originalLoopExtent[i]};

   // Compute reassociation map from the original op to the expanded op.
   unsigned sum = 0;
@@ -601,6 +608,30 @@
 LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                     const ExpansionInfo &expansionInfo,
                                     PatternRewriter &rewriter) {
+  // Current reshape only supports expansion of a dynamic dim when only one of
+  // the expanded dims is dynamic.
+  for (auto originalShape : llvm::enumerate(expansionInfo.getOriginalShape()))
+    if (ShapedType::isDynamic(originalShape.value())) {
+      // All but one of the expanded dims must be static.
+      bool foundDynamicExpandedDim = false;
+      for (auto expandedShape :
+           expansionInfo.getExpandedShapeOfDim(originalShape.index())) {
+        if (ShapedType::isDynamic(expandedShape)) {
+          if (foundDynamicExpandedDim) {
+            return rewriter.notifyMatchFailure(
+                genericOp,
+                "cannot expand dynamic dims into multiple dynamic dims");
+          }
+          foundDynamicExpandedDim = true;
+        }
+      }
+      if (!foundDynamicExpandedDim) {
+        return rewriter.notifyMatchFailure(
+            genericOp, "dynamic dim expansion needs at least one dynamic dim "
+                       "in result shape");
+      }
+    }
+
   if (!genericOp.hasIndexSemantics())
     return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
@@ -731,13 +762,16 @@
   RankedTensorType expandedType = isExpanding
                                       ? expandingReshapeOp.getResultType()
                                       : collapsingReshapeOp.getSrcType();
+  RankedTensorType collapsedType = isExpanding
+                                       ? expandingReshapeOp.getSrcType()
+                                       : collapsingReshapeOp.getResultType();
   ExpansionInfo expansionInfo;
   if (failed(expansionInfo.compute(
           genericOp, fusableOpOperand,
           isExpanding ? expandingReshapeOp.getReassociationMaps()
                       : collapsingReshapeOp.getReassociationMaps(),
-          expandedType.getShape(), rewriter)))
+          expandedType.getShape(), collapsedType.getShape(), rewriter)))
     return llvm::None;

   if (failed(isGenericOpExpandable(genericOp, expansionInfo, rewriter)))
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -507,3 +507,26 @@
 // FOLDUNITDIM-SAME:     ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
 // FOLDUNITDIM-SAME:     outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)

+// -----
+
+func @no_fuse_dynamic_dims(%arg0: tensor<?x?xf32>) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
+  %1 = tensor.dim %0, %c0 : tensor<?xf32>
+  %2 = linalg.init_tensor [%1] : tensor<?xf32>
+  %3 = linalg.generic {
+    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+    iterator_types = ["parallel"]}
+    ins(%0 : tensor<?xf32>) outs(%2 : tensor<?xf32>) {
+  ^bb0(%arg1 : f32, %arg2: f32):
+    %4 = arith.addf %arg1, %arg1 : f32
+    linalg.yield %4 : f32
+  } -> tensor<?xf32>
+  return %3 : tensor<?xf32>
+}
+// CHECK: func @no_fuse_dynamic_dims
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?xf32>
+// CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:     ins(%[[RESHAPE]] : tensor<?xf32>)
+// CHECK:   return %[[GENERIC]]