diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2963,11 +2963,16 @@
     sliceMaskDimSizes.reserve(maskDimSizes.size());
     for (auto [maskDimSize, sliceOffset, sliceSize] :
-         llvm::zip_equal(maskDimSizes, sliceOffsets, sliceSizes)) {
+         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
       int64_t sliceMaskDimSize = std::max(
           static_cast<int64_t>(0),
           std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
       sliceMaskDimSizes.push_back(sliceMaskDimSize);
     }
+    // The offsets/sizes may cover only the leading dimensions; the trailing
+    // dimensions are not sliced and keep their original mask size.
+    size_t numSlicedDims = sliceMaskDimSizes.size();
+    llvm::append_range(sliceMaskDimSizes,
+                       llvm::drop_begin(maskDimSizes, numSlicedDims));
     // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
     // region is a conjunction of mask dim intervals).
     if (llvm::is_contained(sliceMaskDimSizes, 0))
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2083,3 +2083,18 @@
   %1 = vector.extract %0[0, 0, 31] : vector<1x1x32x1xf32>
   return %1: vector<1xf32>
 }
+
+// -----
+
+// Only the leading dimension is sliced ([3, 8) of a [0, 10) mask -> 5); the
+// trailing dimension is not sliced, so its mask size (4) must carry over.
+// CHECK-LABEL: func.func @extract_strided_slice_of_constant_mask
+func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1> {
+  // CHECK-NEXT: %[[RES:.*]] = vector.constant_mask [5, 4] : vector<5x7xi1>
+  // CHECK-NEXT: return %[[RES]] : vector<5x7xi1>
+  %c4 = arith.constant 4 : index
+  %c10 = arith.constant 10 : index
+  %mask = vector.create_mask %c10, %c4 : vector<12x7xi1>
+  %res = vector.extract_strided_slice %mask {offsets = [3], sizes = [5], strides = [1]} : vector<12x7xi1> to vector<5x7xi1>
+  return %res : vector<5x7xi1>
+}