diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -23,12 +23,26 @@
 // Returns `vector<1xT>` if `oldType` only has one element.
 static VectorType trimLeadingOneDims(VectorType oldType) {
   ArrayRef<int64_t> oldShape = oldType.getShape();
-  ArrayRef<int64_t> newShape =
-      oldShape.drop_while([](int64_t dim) { return dim == 1; });
+  ArrayRef<int64_t> newShape = oldShape;
+
+  ArrayRef<bool> oldScalableDims = oldType.getScalableDims();
+  ArrayRef<bool> newScalableDims = oldScalableDims;
+
+  // Drop leading unit dims in lockstep with their scalability flags. Scalable
+  // dims are never dropped, not even `[1]`: their runtime size is a multiple
+  // of vscale, so they are not known to be unit dims.
+  while (!newShape.empty() && newShape.front() == 1 &&
+         !newScalableDims.front()) {
+    newShape = newShape.drop_front(1);
+    newScalableDims = newScalableDims.drop_front(1);
+  }
+
   // Make sure we have at least 1 dimension per vector type requirements.
-  if (newShape.empty())
+  if (newShape.empty()) {
     newShape = oldShape.take_back();
-  return VectorType::get(newShape, oldType.getElementType());
+    newScalableDims = oldScalableDims.take_back();
+  }
+  return VectorType::get(newShape, oldType.getElementType(), newScalableDims);
 }
 
 /// Return a smallVector of size `rank` containing all zeros.
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -18,6 +18,22 @@
   return %1: vector<4x2xf32>
 }
 
+// Regression test. Previously, this example would trigger
+// CastAwayElementwiseLeadingOneDim because:
+// * `vector<2x[4]x1xf32>` would be reformulated as
+// * `vector<2x4x1xf32>` (i.e. the scalability of `[4]` was lost).
+// With the updated shape, the conversion pattern would incorrectly assume that
+// some leading dims had been dropped.
+// CHECK-LABEL: func.func @no_change(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2x[4]x1xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<2x[4]x1xf32>)
+// CHECK-NEXT: %[[VAL_2:.*]] = arith.mulf %[[VAL_0]], %[[VAL_1]] : vector<2x[4]x1xf32>
+// CHECK-NEXT: return %[[VAL_2]]
+func.func @no_change(%arg0: vector<2x[4]x1xf32>, %arg1: vector<2x[4]x1xf32>) -> vector<2x[4]x1xf32> {
+  %1 = arith.mulf %arg0, %arg1 : vector<2x[4]x1xf32>
+  return %1 : vector<2x[4]x1xf32>
+}
+
 // CHECK-LABEL: func @add4x4
 // CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -55,6 +55,7 @@
 
   void getDependentDialects(DialectRegistry &registry) const override {
     registry.insert();
+    registry.insert();
   }
 
   Option<bool> unroll{*this, "unroll", llvm::cl::desc("Include unrolling"),