diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -118,11 +118,12 @@ /// dimension is statically known, or -1 otherwise. SmallVector getStaticShape(LinalgOp linalgOp); -/// Returns the statically-known loop ranges of the `linalgOp`. Applies the -/// inverse of the concatenated indexing maps to the result of `getStaticShape`. -/// Returns None if inverting the concatenated indexing map fails. Returns -1 +/// Returns the statically-known loop ranges of the `linalgOp`. Composes +/// `linalgOp.getShapesToLoopsMap()` with the result of `getStaticShape`. +/// Returns None if `linalgOp.getShapesToLoopsMap()` fails. Returns -1 /// for non-statically-known loop ranges. Optional> getStaticLoopRanges(LinalgOp linalgOp); + /// Apply the permutation defined by `permutation` to `inVec`. /// Element `i` in `inVec` is mapped to location `j = permutation[i]`. /// E.g.: for an input vector `inVec = ['a', 'b', 'c']` and a permutation vector diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp @@ -415,7 +415,8 @@ // - All the indexing maps for operands in linalgOp are projected // permutations. // - The indexing map at the position representing the fused tensor is a - // permutation. + // projected permutation. + // - The fused tensor is not a scalar. // - All the loops in linalgOp are parallel loops. 
return isa(linalgOp.getOperation()) && linalgOp.hasTensorSemantics() && @@ -426,7 +427,8 @@ .getValue() .isProjectedPermutation(); }) && - linalgOp.getIndexingMap(fusedTensorIndex).isPermutation() && + linalgOp.getIndexingMap(fusedTensorIndex).isProjectedPermutation() && + linalgOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 && llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) { return attr.cast().getValue() == getParallelIteratorTypeName(); @@ -447,8 +449,6 @@ reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank(); RankedTensorType expandedType = isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType(); - RankedTensorType foldedType = - isExpanding ? reshapeOp.getSrcType() : reshapeOp.getResultType(); AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex); // The reshape is folding/expanding consecutive dimensions. Given the indexing @@ -456,9 +456,15 @@ // the original op is expanded into. Also record the shape of the expanded // dimensions. ArrayRef expandedShape = expandedType.getShape(); - SmallVector numFoldedDims(foldedType.getRank(), 0); + Optional> origOpLoopRange = + getStaticLoopRanges(linalgOp); + if (!origOpLoopRange) { + linalgOp.emitError("unable to find loop range for operation"); + return llvm::None; + } + SmallVector numFoldedDims(fusedIndexMap.getNumDims(), 1); SmallVector, 4> expandedDimsShape( - foldedType.getRank()); + fusedIndexMap.getNumDims()); auto reassociationMaps = reshapeOp.getReassociationMaps(); for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) { unsigned pos = resultExpr.value().cast().getPosition(); @@ -468,6 +474,10 @@ expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]); expandedDimsShape[pos].assign(shape.begin(), shape.end()); } + // Any dimension not expanded above keeps its original loop range. 
+ for (unsigned i : llvm::seq(0, fusedIndexMap.getNumDims())) + if (expandedDimsShape[i].empty()) + expandedDimsShape[i] = {(*origOpLoopRange)[i]}; if (isa(linalgOp.getOperation())) { // For indexed generic op, the region contains arguments that represent the @@ -477,6 +487,8 @@ // front) are statically know. For dynamic case, we would need shape // information on these dimensions to get these. for (auto &expandedShape : expandedDimsShape) { + if (expandedShape.size() == 1) + continue; for (int64_t expandedDimShape : llvm::make_range( std::next(expandedShape.begin()), expandedShape.end())) { if (ShapedType::isDynamic(expandedDimShape)) { diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -104,13 +104,18 @@ auto shape = v.getType().cast().getShape(); res.append(shape.begin(), shape.end()); } + if (linalgOp.getNumInitTensors()) + return res; + for (Value v : linalgOp.getOperation()->getResults()) { + auto shape = v.getType().cast().getShape(); + res.append(shape.begin(), shape.end()); + } return res; } Optional> getStaticLoopRanges(LinalgOp linalgOp) { SmallVector viewSizes = getStaticShape(linalgOp); - AffineMap invertedMap = - inversePermutation(concatAffineMaps(linalgOp.getIndexingMaps())); + AffineMap invertedMap = linalgOp.getShapesToLoopsMap(); if (!invertedMap) return {}; return invertedMap.compose(viewSizes); diff --git a/mlir/test/Dialect/Linalg/reshape_fusion.mlir b/mlir/test/Dialect/Linalg/reshape_fusion.mlir --- a/mlir/test/Dialect/Linalg/reshape_fusion.mlir +++ b/mlir/test/Dialect/Linalg/reshape_fusion.mlir @@ -344,3 +344,56 @@ // CHECK: %[[T9:.+]] = addi %[[T7]], %[[T8]] // CHECK: %[[T10:.+]] = index_cast %[[ARG7]] // CHECK: %[[T11:.+]] = addi %[[T9]], %[[T10]] + +// ----- + +func @reshape_as_producer_projected_permutation + (%arg0 : tensor<33x8x?xi32>) -> tensor<264x?x4xi32> { + %0 = linalg.tensor_reshape %arg0 
[affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d2)>] + : tensor<33x8x?xi32> into tensor<264x?xi32> + %1 = linalg.indexed_generic + {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1)>, + affine_map<(d0, d1, d2) -> (d0, d1, d2)>], + iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<264x?xi32>) { + ^bb0(%arg1: index, %arg2: index, %arg3: index, %arg4: i32): // no predecessors + %2 = index_cast %arg1 : index to i32 + %3 = addi %arg4, %2 : i32 + %4 = index_cast %arg2 : index to i32 + %5 = addi %3, %4 : i32 + %6 = index_cast %arg3 : index to i32 + %7 = addi %5, %6 : i32 + linalg.yield %7 : i32 + } -> tensor<264x?x4xi32> + return %1 : tensor<264x?x4xi32> +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0 * 8 + d1)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1)> +// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d2)> +// CHECK-DAG: #[[MAP5:.+]] = affine_map<(d0, d1, d2, d3) -> (d3)> +// CHECK: @reshape_as_producer_projected_permutation +// CHECK-SAME: %[[ARG0:.+]]: tensor<33x8x?xi32> +// CHECK: %[[RES:.+]] = linalg.indexed_generic +// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]] +// CHECK-SAME: ins(%[[ARG0]] : tensor<33x8x?xi32>) +// CHECK: ^{{.+}}( +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index, +// CHECK-SAME: %[[ARG5:[a-zA-Z0-9]+]]: i32) +// CHECK: %[[T0:.+]] = affine.apply #[[MAP2]](%[[ARG1]], %[[ARG2]]) +// CHECK: %[[T1:.+]] = index_cast %[[T0]] : index to i32 +// CHECK: %[[T2:.+]] = addi %[[ARG5]], %[[T1]] : i32 +// CHECK: %[[T3:.+]] = index_cast %[[ARG3]] : index to i32 +// CHECK: %[[T4:.+]] = addi %[[T2]], %[[T3]] : i32 +// CHECK: %[[T5:.+]] = index_cast %[[ARG4]] : index to i32 
+// CHECK: %[[T6:.+]] = addi %[[T4]], %[[T5]] : i32 +// CHECK: linalg.yield %[[T6]] : i32 +// CHECK: %[[RES2:.+]] = linalg.tensor_reshape %[[RES]] +// CHECK-SAME: [#[[MAP3]], #[[MAP4]], #[[MAP5]]] +// CHECK-SAME: : tensor<33x8x?x4xi32> into tensor<264x?x4xi32> +// CHECK: return %[[RES2]] : tensor<264x?x4xi32>