diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -9,6 +9,7 @@
 // This file implements the linalg dialect Fusion on tensors operations pass.
 //
 //===----------------------------------------------------------------------===//
+#include <numeric>
 #include "PassDetail.h"
@@ -650,6 +651,20 @@
   return RankedTensorType::get(expandedShape, originalType.getElementType());
 }
 
+static RankedTensorType getCollapsedType(RankedTensorType originalType,
+                                         AffineMap indexingMap,
+                                         const ExpansionInfo &expansionInfo) {
+  SmallVector<int64_t> expandedShape;
+  for (AffineExpr expr : indexingMap.getResults()) {
+    unsigned dim = expr.cast<AffineDimExpr>().getPosition();
+    auto dimExpansion = expansionInfo.getExpandedShapeOfDim(dim);
+    expandedShape.push_back(std::accumulate(
+        dimExpansion.begin(), dimExpansion.end(), int64_t(1),
+        [](int64_t acc, int64_t i) { return std::max<int64_t>(acc * i, -1); }));
+  }
+  return RankedTensorType::get(expandedShape, originalType.getElementType());
+}
+
 /// Returns the reassociation maps to use in the `tensor.expand_shape`
 /// operation to convert the operands of the original operation to operands of
 /// the expanded operation. The same method is used to compute the
@@ -748,6 +763,13 @@
         return getIndexingMapInExpandedOp(rewriter, m, expansionInfo);
       }));
 
+  // The input type to the generic can be more dynamic than the input of the
+  // reshape. Insert a cast if necessary to make them match.
+  auto maybeCast = [&](Value op, Type t) {
+    return op.getType() != t
+               ? rewriter.create<tensor::CastOp>(genericOp.getLoc(), t, op)
+               : op;
+  };
   SmallVector<Value> expandedOpOperands;
   expandedOpOperands.reserve(genericOp.getNumInputs());
   for (OpOperand *opOperand : genericOp.getInputOperands()) {
@@ -761,13 +783,16 @@
       RankedTensorType expandedOperandType =
          getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
                          indexingMap, expansionInfo);
+      RankedTensorType collapsedOperandType =
+          getCollapsedType(opOperand->get().getType().cast<RankedTensorType>(),
+                           indexingMap, expansionInfo);
       if (expandedOperandType != opOperand->get().getType()) {
         // Reshape the operand to get the right type.
         SmallVector<ReassociationIndices> reassociation =
             getReassociationForExpansion(indexingMap, expansionInfo);
         expandedOpOperands.push_back(rewriter.create<tensor::ExpandShapeOp>(
-            genericOp.getLoc(), expandedOperandType, opOperand->get(),
-            reassociation));
+            genericOp.getLoc(), expandedOperandType,
+            maybeCast(opOperand->get(), collapsedOperandType), reassociation));
         continue;
       }
     }
@@ -781,12 +806,15 @@
     RankedTensorType expandedOutputType =
         getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
                         indexingMap, expansionInfo);
+    RankedTensorType collapsedOperandType =
+        getCollapsedType(opOperand->get().getType().cast<RankedTensorType>(),
+                         indexingMap, expansionInfo);
     if (expandedOutputType != opOperand->get().getType()) {
       SmallVector<ReassociationIndices> reassociation =
          getReassociationForExpansion(indexingMap, expansionInfo);
      outputs.push_back(rewriter.create<tensor::ExpandShapeOp>(
-          genericOp.getLoc(), expandedOutputType, opOperand->get(),
-          reassociation));
+          genericOp.getLoc(), expandedOutputType,
+          maybeCast(opOperand->get(), collapsedOperandType), reassociation));
     }
   }
 
diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
--- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir
+++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir
@@ -507,3 +507,26 @@
 // FOLDUNITDIM-SAME:     ins(%[[ARG0]], %[[RESHAPE]] : tensor<1x?x1x2x1x4xf32>, tensor<1x?x1x2x1x4xf32>)
 // FOLDUNITDIM-SAME:     outs(%{{.+}} : tensor<1x?x1x2x1x4xf32>)
 
+// -----
+
+func @result_less_dynamic(%arg0: tensor<1x1xi64>, %arg1: tensor<?xi64>) -> tensor<1xi64> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
+  %1 = linalg.init_tensor [1] : tensor<1xi64>
+  %2 = linalg.generic
+    {indexing_maps = [affine_map<(d0) -> (d0)>,
+                      affine_map<(d0) -> (d0)>,
+                      affine_map<(d0) -> (d0)>],
+     iterator_types = ["parallel"]}
+    ins(%0, %arg1 : tensor<1xi64>, tensor<?xi64>)
+    outs(%1 : tensor<1xi64>) {
+  ^bb0(%arg4: i64, %arg5: i64, %arg6: i64):  // no predecessors
+    %3 = arith.addi %arg4, %arg5 : i64
+    linalg.yield %3 : i64
+  } -> tensor<1xi64>
+  return %2 : tensor<1xi64>
+}
+
+// CHECK: func @result_less_dynamic
+// CHECK-SAME:   %[[ARG1:[^ ]+]]: tensor<?xi64>
+// CHECK:        %[[CAST:.+]] = tensor.cast %[[ARG1]] : tensor<?xi64> to tensor<1xi64>
+// CHECK:        tensor.expand_shape %[[CAST]] {{\[}}[0, 1]] : tensor<1xi64> into tensor<1x1xi64>
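
For reference, a rough sketch of the fused IR the new test is driving at (illustrative only, not verbatim pass output; SSA names and the standalone init_tensor are made up): the dynamically shaped operand is first tensor.cast to the static type implied by the collapsed operand, then expand_shape'd so every operand of the expanded generic agrees on the pre-collapse shape:

  // Illustrative post-fusion form of @result_less_dynamic (names hypothetical).
  #map = affine_map<(d0, d1) -> (d0, d1)>
  func @result_less_dynamic_fused(%arg0: tensor<1x1xi64>, %arg1: tensor<?xi64>) -> tensor<1xi64> {
    %cast = tensor.cast %arg1 : tensor<?xi64> to tensor<1xi64>
    %expanded = tensor.expand_shape %cast [[0, 1]] : tensor<1xi64> into tensor<1x1xi64>
    %init = linalg.init_tensor [1, 1] : tensor<1x1xi64>
    %fused = linalg.generic {indexing_maps = [#map, #map, #map],
                             iterator_types = ["parallel", "parallel"]}
        ins(%arg0, %expanded : tensor<1x1xi64>, tensor<1x1xi64>)
        outs(%init : tensor<1x1xi64>) {
      ^bb0(%a: i64, %b: i64, %c: i64):
        %sum = arith.addi %a, %b : i64
        linalg.yield %sum : i64
    } -> tensor<1x1xi64>
    %result = tensor.collapse_shape %fused [[0, 1]] : tensor<1x1xi64> into tensor<1xi64>
    return %result : tensor<1xi64>
  }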