diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -383,9 +383,12 @@ ArrayRef<Range> sliceParams); /// Determine which dimensions are linearized by a `tensor.collapse_shape` op by -/// inspecting its reassociation indices. +/// inspecting its reassociation indices and the shape of the input tensor. An +/// output index is considered linearized if it corresponds to multiple non-unit +/// input dimensions. llvm::SmallBitVector -getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices); +getLinearizedDimensions(ArrayRef<OpFoldResult> collapseShapeInputShape, + ArrayRef<ReassociationIndices> reassociationIndices); /// Given the parameters for both operations in a `CollapseShape->ExtractSlice` /// chain and reified source and result shapes of the CollapseShapeOp, this @@ -427,7 +430,8 @@ collapseShapeInputShape(collapseShapeInputShape), collapseShapeOutputShape(collapseShapeOutputShape), sliceParams(extractSliceParams), - linearizedDimensions(getLinearizedDimensions(reassociationIndices)), + linearizedDimensions(getLinearizedDimensions(collapseShapeInputShape, + reassociationIndices)), slicedDimensions(getSlicedDimensions(collapseShapeOutputShape, extractSliceParams)) {} diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp @@ -26,8 +26,8 @@ using namespace mlir::tensor; /// Get the dimension size of a value of RankedTensor type at the -OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, Value rankedTensor, - int64_t dimIdx) { +static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, + Value rankedTensor, int64_t dimIdx) { RankedTensorType tensorType = rankedTensor.getType().cast<RankedTensorType>(); if 
(!tensorType.isDynamicDim(dimIdx)) { return b.getIndexAttr(tensorType.getDimSize(dimIdx)); @@ -116,18 +116,16 @@ SmallVector<ReassociationIndices> reassociationIndices = op.getReassociationIndices(); + SmallVector<OpFoldResult> collapseShapeInputShape = + getShapeDimSizes(b, op.getLoc(), op.getSrc()); + // Determine which of the CollapseShapeOp's result dimensions are sliced // and/or linearized. llvm::SmallBitVector linearizedDimensions = - getLinearizedDimensions(reassociationIndices); + getLinearizedDimensions(collapseShapeInputShape, reassociationIndices); llvm::SmallBitVector slicedDimensions = getSlicedDimensions(collapseShapeOutputShape, sliceParams); - auto collapseShapeInputShape = getShapeDimSizes(b, op.getLoc(), op.getSrc()); - - SmallVector<OpFoldResult> srcShape = - getShapeDimSizes(b, op->getLoc(), op.getSrc()); - SmallVector<OpFoldResult> tileSizes; for (unsigned i = 0; i < sliceParams.size(); i++) { if (slicedDimensions[i] && linearizedDimensions[i]) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -289,11 +289,30 @@ return mask; } +/// Returns the index of the only non-unit dimension among `indices` of `shape`, +/// if such a dimension exists. Otherwise, return none. +static Optional<int64_t> getUniqueNonUnitDim(ArrayRef<int64_t> indices, + ArrayRef<OpFoldResult> shape) { + // Return none if more than one of the dimensions in this group are not 1. 
+ Optional<int64_t> dimIndex = None; + for (int64_t idx : indices) { + if (!isConstantIntValue(shape[idx], 1)) { + if (dimIndex != None) + return None; + dimIndex = idx; + } + } + return dimIndex; +} + llvm::SmallBitVector mlir::getLinearizedDimensions( + ArrayRef<OpFoldResult> collapseShapeInputShape, ArrayRef<ReassociationIndices> reassociationIndices) { llvm::SmallBitVector result(reassociationIndices.size()); - for (const auto &it : llvm::enumerate(reassociationIndices)) - result[it.index()] = it.value().size() > 1; + for (const auto &it : llvm::enumerate(reassociationIndices)) { + if (!getUniqueNonUnitDim(it.value(), collapseShapeInputShape)) + result[it.index()] = it.value().size() > 1; + } return result; } @@ -329,7 +348,20 @@ continue; } - // Case 3: A single index, but it may be sliced. + // Case 3a: A collapsing of unit dims with at most one non-unit dim. We + // don't consider this a linearization. + if (Optional<int64_t> uniqueNonUnitDim = + getUniqueNonUnitDim(it.value(), collapseShapeInputShape)) { + for (int64_t dim : it.value()) { + if (dim == *uniqueNonUnitDim) + offsetsSizesAndStrides.push_back(sliceParams[it.index()]); + else + offsetsSizesAndStrides.push_back(Range{zeroAttr, oneAttr, oneAttr}); + } + continue; + } + + // Case 3b: A single non-linearized, but possibly sliced, dimension. offsetsSizesAndStrides.push_back(sliceParams[it.index()]); } return offsetsSizesAndStrides; diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir --- a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir +++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir @@ -177,3 +177,44 @@ // CHECK: return %[[res]] return %slice : tensor<330x?xf32> } + +// ----- + +// The below tests verify that a dimension which is the result of collapsing at +// most one non-unit dim is not considered a linearization. 
+ +// CHECK: @collapse_and_slice_unit_dim(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index +func.func @collapse_and_slice_unit_dim(%input: tensor<1x11x100xf32>, %offt: index, %size: index) -> tensor<?x100xf32> { + %collapsed = tensor.collapse_shape %input [[0, 1], [2]] : tensor<1x11x100xf32> into tensor<11x100xf32> + %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1] : tensor<11x100xf32> to tensor<?x100xf32> + // CHECK-NOT: scf.for + // CHECK: %[[init:.+]] = tensor.empty(%[[arg2]]) + // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, %[[arg1]], 0] [1, %[[arg2]], 100] [1, 1, 1] + // CHECK-SAME: tensor<1x11x100xf32> to tensor<1x?x100xf32> + // CHECK: %[[c:.+]] = tensor.collapse_shape %[[e]] {{\[}}[0, 1], [2]] + // CHECK: %[[res:.+]] = tensor.insert_slice %[[c]] into %[[init]][0, 0] [%[[arg2]], 100] [1, 1] + return %slice : tensor<?x100xf32> +} + +// CHECK: @collapse_and_slice_multiple_unit_dim_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: index, %[[arg2:.+]]: index +func.func @collapse_and_slice_multiple_unit_dim_dynamic(%input: tensor<1x?x1x100xf32>, %offt: index, %size: index) -> tensor<?x100xf32> { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<1x?x1x100xf32> into tensor<?x100xf32> + %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 100] [1, 1] : tensor<?x100xf32> to tensor<?x100xf32> + // CHECK-NOT: scf.for + // CHECK: %[[init:.+]] = tensor.empty(%[[arg2]]) + // CHECK: %[[e:.+]] = tensor.extract_slice %[[arg0]][0, %[[arg1]], 0, 0] [1, %[[arg2]], 1, 100] [1, 1, 1, 1] + // CHECK-SAME: tensor<1x?x1x100xf32> to tensor<1x?x1x100xf32> + // CHECK: %[[c:.+]] = tensor.collapse_shape %[[e]] {{\[}}[0, 1, 2], [3]] + // CHECK: %[[res:.+]] = tensor.insert_slice %[[c]] into %[[init]][0, 0] [%[[arg2]], 100] [1, 1] + return %slice : tensor<?x100xf32> +} + +// Edge case where all collapsed dims are unit dims +// CHECK: @collapse_and_slice_multiple_all_unit_dim(%[[arg0:.+]]: tensor<{{.*}}>) +func.func @collapse_and_slice_multiple_all_unit_dim(%input: tensor<1x1x1x100xf32>) -> 
tensor<1x100xf32> { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<1x1x1x100xf32> into tensor<1x100xf32> + %slice = tensor.extract_slice %collapsed [0, 0] [1, 100] [1, 1] : tensor<1x100xf32> to tensor<1x100xf32> + return %slice : tensor<1x100xf32> + // CHECK: %[[collapse:.+]] = tensor.collapse_shape %[[arg0]] {{\[}}[0, 1, 2], [3]] : tensor<1x1x1x100xf32> into tensor<1x100xf32> + // CHECK: return %[[collapse]] +}