Propagate scalable-dimension flags when building a rank-reduced vector type.

In VectorOps.cpp, the hunk that appends to `inferredReturnTypes` now passes
`vectorType.getScalableDims().drop_front(n)` as a third argument to
`VectorType::get`, dropping the same `n` leading entries as it drops from the
shape. The tests added to vector-transforms.mlir cover a fixed-size
(`vector<1x4x1xf32>`) and a scalable (`vector<1x[4]x1xf32>`) `arith.mulf`; the
CHECK lines expect a `vector.extract ...[0]` of each operand, the multiply on
the rank-reduced type, and a `vector.broadcast` back to the original type.

NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Line breaks were restored so that each hunk matches its header counts
(-1151,7 +1151,8 and -34,6 +34,32). Indentation and intra-line spacing could
not be recovered and follow conventional layout, so apply with
whitespace-insensitive matching or confirm against the tree first.
NOTE(review): the collapsed text read `std::min(adaptor.getPosition()...)`.
`<size_t>` was restored on the assumption that it had been stripped as
markup-like text: the old-side count of 7 implies the statement wrapped onto
two lines. Confirm against the target file.

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1151,7 +1151,8 @@
     auto n =
         std::min<size_t>(adaptor.getPosition().size(), vectorType.getRank());
     inferredReturnTypes.push_back(VectorType::get(
-        vectorType.getShape().drop_front(n), vectorType.getElementType()));
+        vectorType.getShape().drop_front(n), vectorType.getElementType(),
+        vectorType.getScalableDims().drop_front(n)));
   }
   return success();
 }
diff --git a/mlir/test/Dialect/Vector/vector-transforms.mlir b/mlir/test/Dialect/Vector/vector-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-transforms.mlir
@@ -34,6 +34,32 @@
   return %1 : vector<2x[4]x1xf32>
 }
 
+// CHECK-LABEL: func.func @cast_away_leading_one_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x4x1xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x4x1xf32>) -> vector<1x4x1xf32> {
+// CHECK: %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : vector<1x4x1xf32>
+// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_1]][0] : vector<1x4x1xf32>
+// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<4x1xf32>
+// CHECK: %[[VAL_5:.*]] = vector.broadcast %[[VAL_4]] : vector<4x1xf32> to vector<1x4x1xf32>
+// CHECK: return %[[VAL_5]] : vector<1x4x1xf32>
+func.func @cast_away_leading_one_dim(%arg0: vector<1x4x1xf32>, %arg1: vector<1x4x1xf32>) -> vector<1x4x1xf32> {
+  %1 = arith.mulf %arg0, %arg1 : vector<1x4x1xf32>
+  return %1: vector<1x4x1xf32>
+}
+
+// CHECK-LABEL: func.func @scalable_cast_away_leading_one_dim(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<1x[4]x1xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: vector<1x[4]x1xf32>) -> vector<1x[4]x1xf32> {
+// CHECK: %[[VAL_2:.*]] = vector.extract %[[VAL_0]][0] : vector<1x[4]x1xf32>
+// CHECK: %[[VAL_3:.*]] = vector.extract %[[VAL_1]][0] : vector<1x[4]x1xf32>
+// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<[4]x1xf32>
+// CHECK: %[[VAL_5:.*]] = vector.broadcast %[[VAL_4]] : vector<[4]x1xf32> to vector<1x[4]x1xf32>
+// CHECK: return %[[VAL_5]] : vector<1x[4]x1xf32>
+func.func @scalable_cast_away_leading_one_dim(%arg0: vector<1x[4]x1xf32>, %arg1: vector<1x[4]x1xf32>) -> vector<1x[4]x1xf32> {
+  %1 = arith.mulf %arg0, %arg1 : vector<1x[4]x1xf32>
+  return %1: vector<1x[4]x1xf32>
+}
+
 // CHECK-LABEL: func @add4x4
 // CHECK: %[[S1:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[S2:.*]] = vector.extract_strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>