diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -169,10 +169,8 @@ unsigned newPosRank = std::max(0, oldPosRank - dstDropCount); SmallVector newPositions = llvm::to_vector( insertOp.getPosition().getValue().take_back(newPosRank)); - if (srcDropCount >= dstDropCount) { - auto zeroAttr = rewriter.getZeroAttr(rewriter.getI64Type()); - newPositions.resize(newPosRank + srcDropCount, zeroAttr); - } + newPositions.resize(newDstType.getRank() - newSrcRank, + rewriter.getI64IntegerAttr(0)); auto newInsertOp = rewriter.create( loc, newDstType, newSrcVector, newDstVector, diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir --- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir @@ -316,3 +316,15 @@ %0 = vector.insert %s, %v [5] : vector<1x4xf32> into vector<8x1x4xf32> return %0: vector<8x1x4xf32> } + +// CHECK-LABEL: func @cast_away_insert_leading_one_dims_one_two_dest +// CHECK-SAME: (%[[S:.+]]: vector<1x8xi1>, %[[V:.+]]: vector<1x1x8x1x8xi1>) +// CHECK: %[[EXTRACTS:.+]] = vector.extract %[[S]][0] : vector<1x8xi1> +// CHECK: %[[EXTRACTV:.+]] = vector.extract %[[V]][0, 0] : vector<1x1x8x1x8xi1> +// CHECK: %[[INSERT:.+]] = vector.insert %[[EXTRACTS]], %[[EXTRACTV]] [7, 0] : vector<8xi1> into vector<8x1x8xi1> +// CHECK: %[[BCAST:.+]] = vector.broadcast %[[INSERT]] : vector<8x1x8xi1> to vector<1x1x8x1x8xi1> +// CHECK: return %[[BCAST]] +func.func @cast_away_insert_leading_one_dims_one_two_dest(%s: vector<1x8xi1>, %v: vector<1x1x8x1x8xi1>) -> vector<1x1x8x1x8xi1> { + %0 = vector.insert %s, %v [0, 0, 7] : vector<1x8xi1> into vector<1x1x8x1x8xi1> + return %0: vector<1x1x8x1x8xi1> +}