diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -110,6 +110,58 @@
   return success();
 }
 
+/// Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or
+/// rank-extending tensor.insert_slice op.
+///
+/// `reducedShape` is the shape of the rank-reduced tensor (the result of an
+/// extract_slice or the source of an insert_slice) and `mixedSizes` are the
+/// op's slice sizes. Only sizes that are a static (attribute) 1 can be dropped;
+/// SSA-value sizes and static non-unit sizes are always preserved. A static
+/// unit size is preserved if the next unmatched dim of `reducedShape` is also
+/// 1; otherwise it is dropped.
+static llvm::SmallBitVector getDroppedDims(ArrayRef<int64_t> reducedShape,
+                                           ArrayRef<OpFoldResult> mixedSizes) {
+  llvm::SmallBitVector droppedDims(mixedSizes.size());
+  // Position of the next unmatched dim in `reducedShape`. Unsigned to match
+  // `reducedShape.size()` and avoid signed/unsigned comparisons.
+  size_t shapePos = 0;
+
+  for (const auto &size : enumerate(mixedSizes)) {
+    // Rank-reduced dims must have a static unit dimension. SSA-value sizes are
+    // never treated as unit sizes, even if they are defined by a constant.
+    bool isStaticUnitSize =
+        size.value().is<Attribute>() &&
+        size.value().get<Attribute>().cast<IntegerAttr>().getInt() == 1;
+
+    if (shapePos == reducedShape.size()) {
+      // There are no more dims in the reduced shape. All remaining sizes must
+      // be rank-reduced dims.
+      assert(isStaticUnitSize && "expected unit dim");
+      droppedDims.set(size.index());
+      continue;
+    }
+
+    // Dim is preserved if the size is not a static 1.
+    if (!isStaticUnitSize) {
+      ++shapePos;
+      continue;
+    }
+
+    // Dim is preserved if the reduced shape dim is also 1.
+    if (reducedShape[shapePos] == 1) {
+      ++shapePos;
+      continue;
+    }
+
+    // Otherwise: Dim is dropped.
+    droppedDims.set(size.index());
+  }
+
+  // Every dim of the reduced shape must have been matched to some size.
+  assert(shapePos == reducedShape.size() && "dimension mismatch");
+  return droppedDims;
+}
+
 //===----------------------------------------------------------------------===//
 // CastOp
 //===----------------------------------------------------------------------===//
@@ -1740,23 +1792,7 @@
 }
 
 llvm::SmallBitVector ExtractSliceOp::getDroppedDims() {
-  ArrayRef<int64_t> resultShape = getType().getShape();
-  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
-  llvm::SmallBitVector droppedDims(mixedSizes.size());
-  unsigned shapePos = 0;
-  for (const auto &size : enumerate(mixedSizes)) {
-    std::optional<int64_t> sizeVal = getConstantIntValue(size.value());
-    // If the size is not 1, or if the current matched dimension of the result
-    // is the same static shape as the size value (which is 1), then the
-    // dimension is preserved.
-    if (!sizeVal || *sizeVal != 1 ||
-        (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
-      shapePos++;
-      continue;
-    }
-    droppedDims.set(size.index());
-  }
-  return droppedDims;
+  return ::getDroppedDims(getType().getShape(), getMixedSizes());
 }
 
 FailureOr<Value>
@@ -2397,23 +2433,7 @@
 } // namespace
 
 llvm::SmallBitVector InsertSliceOp::getDroppedDims() {
-  ArrayRef<int64_t> resultShape = getType().getShape();
-  SmallVector<OpFoldResult> mixedSizes = getMixedSizes();
-  llvm::SmallBitVector droppedDims(mixedSizes.size());
-  unsigned shapePos = 0;
-  for (const auto &size : enumerate(mixedSizes)) {
-    std::optional<int64_t> sizeVal = getConstantIntValue(size.value());
-    // If the size is not 1, or if the current matched dimension of the result
-    // is the same static shape as the size value (which is 1), then the
-    // dimension is preserved.
-    if (!sizeVal || *sizeVal != 1 ||
-        (shapePos < resultShape.size() && resultShape[shapePos] == 1)) {
-      shapePos++;
-      continue;
-    }
-    droppedDims.set(size.index());
-  }
-  return droppedDims;
+  return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
 }
 
 void InsertSliceOp::getCanonicalizationPatterns(RewritePatternSet &results,