diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp
@@ -161,13 +161,21 @@
     Value newDstVector = rewriter.create(
         loc, insertOp.getDest(), splatZero(dstDropCount));
 
+    // The new position is computed in two steps: (1) drop the leading
+    // `dstDropCount` indices, which address the destination's leading unit
+    // dims being removed; (2) pad with trailing zeros up to the rank
+    // difference between the new destination and new source types, which
+    // covers the leading unit dims removed from the source.
     unsigned oldPosRank = insertOp.getPosition().getValue().size();
-    unsigned newPosRank = newDstType.getRank() - newSrcRank;
+    // Clamp explicitly: `oldPosRank` is unsigned, so a plain subtraction would
+    // wrap around when more destination dims are dropped than indices exist.
+    unsigned newPosRank =
+        oldPosRank > dstDropCount ? oldPosRank - dstDropCount : 0;
     SmallVector newPositions = llvm::to_vector(
         insertOp.getPosition().getValue().take_back(newPosRank));
-    if (newPosRank > oldPosRank) {
-      auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type());
-      newPositions.resize(newPosRank, zeroAttr);
-    }
+    // Always pad up to the rank the new insert needs, independent of how
+    // `srcDropCount` compares to `dstDropCount`.
+    newPositions.resize(newDstType.getRank() - newSrcRank,
+                        rewriter.getZeroAttr(rewriter.getI64Type()));
 
     auto newInsertOp = rewriter.create(
diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
--- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir
@@ -295,6 +295,18 @@
   return %0: vector<1x1x4xf32>
 }
 
+// CHECK-LABEL: func @cast_away_insert_leading_one_dims_rank2_one_dest
+// CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<1x2x1x4xf32>)
+// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>
+// CHECK: %[[EXTRACTV:.+]] = vector.extract %[[V]][0] : vector<1x2x1x4xf32>
+// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [1, 0] : vector<4xf32> into vector<2x1x4xf32>
+// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<2x1x4xf32> to vector<1x2x1x4xf32>
+// CHECK: return %[[BCAST]]
+func.func @cast_away_insert_leading_one_dims_rank2_one_dest(%s: vector<1x4xf32>, %v: vector<1x2x1x4xf32>) -> vector<1x2x1x4xf32> {
+  %0 = vector.insert %s, %v [0, 1] : vector<1x4xf32> into vector<1x2x1x4xf32>
+  return %0: vector<1x2x1x4xf32>
+}
+
 // CHECK-LABEL: func @cast_away_insert_leading_one_dims_non_one_dest
 // CHECK-SAME: (%[[S:.+]]: vector<1x4xf32>, %[[V:.+]]: vector<8x1x4xf32>)
 // CHECK: %[[EXTRACT:.+]] = vector.extract %[[S]][0] : vector<1x4xf32>