diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
--- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h
@@ -166,47 +166,19 @@
 /// 2) if a dimension in the collaped type is dynamic, one and only one of the
 ///    corresponding dimensions in the expanded type should be dynamic. This
 ///    rule is only needed with reshape operations that are expanding.
+LogicalResult reshapeLikeShapesAreCompatible(
+    function_ref<LogicalResult(const Twine &)> emitError,
+    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
+    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape);
+
 template <typename OpTy>
 static LogicalResult verifyReshapeLikeShapes(OpTy op, ShapedType collapsedType,
                                              ShapedType expandedType,
                                              bool isExpandingReshape) {
-  ArrayRef<int64_t> collapsedShape = collapsedType.getShape();
-  ArrayRef<int64_t> expandedShape = expandedType.getShape();
-  unsigned expandedDimStart = 0;
-  for (auto map : llvm::enumerate(op.getReassociationMaps())) {
-    Optional<int64_t> dynamicShape;
-    int64_t linearizedStaticShape = 1;
-    for (auto dim : llvm::enumerate(expandedShape.slice(
-             expandedDimStart, map.value().getNumResults()))) {
-      if (ShapedType::isDynamic(dim.value())) {
-        if (isExpandingReshape && dynamicShape) {
-          return op->emitOpError("invalid to have a single dimension (")
-                 << map.index() << ") expanded into multiple dynamic dims ("
-                 << expandedDimStart + dynamicShape.getValue() << ","
-                 << expandedDimStart + dim.index() << ")";
-        }
-        dynamicShape = dim.index();
-      } else {
-        linearizedStaticShape *= dim.value();
-      }
-    }
-    if (dynamicShape) {
-      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
-        return op->emitOpError("expected dimension ")
-               << map.index()
-               << " of collapsed type to be dynamic since one or more of the "
-                  "corresponding dimensions in the expanded type is dynamic";
-      }
-    } else {
-      if (collapsedShape[map.index()] != linearizedStaticShape) {
-        return op->emitOpError("expected dimension ")
-               << map.index() << " of collapsed type to be static value of "
-               << linearizedStaticShape << " ";
-      }
-    }
-    expandedDimStart += map.value().getNumResults();
-  }
-  return success();
+  return reshapeLikeShapesAreCompatible(
+      [&](const Twine &msg) { return op->emitOpError(msg); },
+      collapsedType.getShape(), expandedType.getShape(),
+      op.getReassociationIndices(), isExpandingReshape);
 }
 
 /// Pattern to collapse producer/consumer reshape ops that are both collapsing
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -608,31 +608,6 @@
 LogicalResult isGenericOpExpandable(GenericOp genericOp,
                                     const ExpansionInfo &expansionInfo,
                                     PatternRewriter &rewriter) {
-  // Current reshape only supports expansion of a dynamic dim when only one of
-  // the expanded dims are dynamic.
-  for (const auto &originalShape :
-       llvm::enumerate(expansionInfo.getOriginalShape()))
-    if (ShapedType::isDynamic(originalShape.value())) {
-      // All but one of the expanded dims must be static.
-      bool foundDynamicExpandedDim = false;
-      for (auto expandedShape :
-           expansionInfo.getExpandedShapeOfDim(originalShape.index())) {
-        if (ShapedType::isDynamic(expandedShape)) {
-          if (foundDynamicExpandedDim) {
-            return rewriter.notifyMatchFailure(
-                genericOp,
-                "cannot expanded dynamic dims into multiple dynamic dims");
-          }
-          foundDynamicExpandedDim = true;
-        }
-      }
-      if (!foundDynamicExpandedDim) {
-        return rewriter.notifyMatchFailure(
-            genericOp, "dynamic dim expansion needs at least one dynamic dim "
-                       "in result shape");
-      }
-    }
-
   if (!genericOp.hasIndexSemantics())
     return success();
   for (unsigned i : llvm::seq<unsigned>(0, expansionInfo.getOrigOpNumDims())) {
@@ -793,13 +768,21 @@
     }
     if (genericOp.isInputTensor(opOperand)) {
       AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+      auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
       RankedTensorType expandedOperandType =
-          getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
-                          indexingMap, expansionInfo);
+          getExpandedType(opOperandType, indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
         // Reshape the operand to get the right type.
         SmallVector<ReassociationIndices> reassociation =
             getReassociationForExpansion(indexingMap, expansionInfo);
+        if (failed(reshapeLikeShapesAreCompatible(
+                [&](const Twine &msg) {
+                  return rewriter.notifyMatchFailure(genericOp, msg);
+                },
+                opOperandType.getShape(), expandedOperandType.getShape(),
+                reassociation,
+                /*isExpandingReshape=*/true)))
+          return llvm::None;
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
             genericOp.getLoc(), expandedOperandType, opOperand->get(),
             reassociation));
@@ -813,12 +796,20 @@
   SmallVector<Value> outputs;
   for (OpOperand *opOperand : genericOp.getOutputOperands()) {
     AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+    auto opOperandType = opOperand->get().getType().cast<RankedTensorType>();
     RankedTensorType expandedOutputType =
-        getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
-                        indexingMap, expansionInfo);
+        getExpandedType(opOperandType, indexingMap, expansionInfo);
     if (expandedOutputType != opOperand->get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
+      if (failed(reshapeLikeShapesAreCompatible(
+              [&](const Twine &msg) {
+                return rewriter.notifyMatchFailure(genericOp, msg);
+              },
+              opOperandType.getShape(), expandedOutputType.getShape(),
+              reassociation,
+              /*isExpandingReshape=*/true)))
+        return llvm::None;
       outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
           genericOp.getLoc(), expandedOutputType, opOperand->get(),
           reassociation));
diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
--- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
+++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp
@@ -276,3 +276,45 @@
   }
   return true;
 }
+
+LogicalResult mlir::reshapeLikeShapesAreCompatible(
+    function_ref<LogicalResult(const Twine &)> emitError,
+    ArrayRef<int64_t> collapsedShape, ArrayRef<int64_t> expandedShape,
+    ArrayRef<ReassociationIndices> reassociationMaps, bool isExpandingReshape) {
+  unsigned expandedDimStart = 0;
+  for (const auto &map : llvm::enumerate(reassociationMaps)) {
+    Optional<int64_t> dynamicShape;
+    int64_t linearizedStaticShape = 1;
+    for (const auto &dim : llvm::enumerate(
+             expandedShape.slice(expandedDimStart, map.value().size()))) {
+      if (ShapedType::isDynamic(dim.value())) {
+        if (isExpandingReshape && dynamicShape) {
+          return emitError("invalid to have a single dimension (" +
+                           Twine(map.index()) +
+                           ") expanded into multiple dynamic dims (" +
+                           Twine(expandedDimStart + dynamicShape.getValue()) +
+                           "," + Twine(expandedDimStart + dim.index()) + ")");
+        }
+        dynamicShape = dim.index();
+      } else {
+        linearizedStaticShape *= dim.value();
+      }
+    }
+    if (dynamicShape) {
+      if (!ShapedType::isDynamic(collapsedShape[map.index()])) {
+        return emitError(
+            "expected dimension " + Twine(map.index()) +
+            " of collapsed type to be dynamic since one or more of the "
+            "corresponding dimensions in the expanded type is dynamic");
+      }
+    } else {
+      if (collapsedShape[map.index()] != linearizedStaticShape) {
+        return emitError("expected dimension " + Twine(map.index()) +
+                         " of collapsed type to be static value of " +
+                         Twine(linearizedStaticShape));
+      }
+    }
+    expandedDimStart += map.value().size();
+  }
+  return success();
+}
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -530,3 +530,30 @@
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME:   ins(%[[RESHAPE]] : tensor)
 // CHECK: return %[[GENERIC]]
+
+// -----
+
+func @no_fuse_mismatched_dynamism(%arg0: tensor<1x1xi64>, %arg1: tensor<?xi64>) -> tensor<1xi64> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
+  %1 = linalg.init_tensor [1] : tensor<1xi64>
+  %2 = linalg.generic
+    {indexing_maps = [affine_map<(d0) -> (d0)>,
+                      affine_map<(d0) -> (d0)>,
+                      affine_map<(d0) -> (d0)>],
+     iterator_types = ["parallel"]}
+    ins(%0, %arg1 : tensor<1xi64>, tensor<?xi64>)
+    outs(%1 : tensor<1xi64>) {
+  ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):  // no predecessors
+    %3 = arith.addi %arg4, %arg5 : i64
+    linalg.yield %3 : i64
+  } -> tensor<1xi64>
+  return %2 : tensor<1xi64>
+}
+
+// CHECK: func @no_fuse_mismatched_dynamism
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1xi64>
+// CHECK-SAME:   %[[ARG1:.+]]: tensor<?xi64>
+// CHECK:   %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]]
+// CHECK:   %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME:     ins(%[[RESHAPE]], %[[ARG1]] : tensor<1xi64>, tensor<?xi64>)
+// CHECK:   return %[[GENERIC]]
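For reference, a minimal sketch (not part of the patch) of calling the new `mlir::reshapeLikeShapesAreCompatible` entry point directly, outside the verifier and fusion call sites shown above. The shapes, the `checkExpansionExample` wrapper, and the use of `ShapedType::kDynamicSize` as the dynamic-dimension sentinel are illustrative assumptions, not code from the change:

```cpp
// Illustrative only: exercises the helper added in ReshapeOpsUtils.cpp.
// Assumes ShapedType::kDynamicSize is the dynamic-dim marker of this MLIR
// revision; the function and variable names here are hypothetical.
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/raw_ostream.h"

using namespace mlir;

static LogicalResult checkExpansionExample() {
  // Collapsed shape [6, ?] expanding into [2, 3, ?] with reassociation
  // [[0, 1], [2]]: the static group multiplies back to 6 and the dynamic
  // collapsed dim maps to exactly one dynamic expanded dim, so the check
  // should succeed. Two dynamic dims in one group would hit emitError.
  SmallVector<int64_t> collapsedShape = {6, ShapedType::kDynamicSize};
  SmallVector<int64_t> expandedShape = {2, 3, ShapedType::kDynamicSize};
  SmallVector<ReassociationIndices> reassociation = {{0, 1}, {2}};
  return reshapeLikeShapesAreCompatible(
      [](const Twine &msg) {
        llvm::errs() << "incompatible reshape: " << msg << "\n";
        return failure();
      },
      collapsedShape, expandedShape, reassociation,
      /*isExpandingReshape=*/true);
}
```

In the patch itself this single entry point backs both the reshape op verifier (reporting through `emitOpError`) and the Linalg fusion-by-expansion pattern (reporting through `notifyMatchFailure`), so the verifier and the rewrite agree on what counts as a compatible expansion.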