diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1467,18 +1467,21 @@ return Value(); auto broadcastOp = cast(defOp); - int64_t rankDiff = broadcastSrcRank - extractResultRank; + int64_t broadcastDstRank = broadcastOp.getResultVectorType().getRank(); + // Detect all the positions that come from "dim-1" broadcasting. // These dimensions correspond to "dim-1" broadcasted dims; set the mathching // extract position to `0` when extracting from the source operand. llvm::SetVector broadcastedUnitDims = broadcastOp.computeBroadcastedUnitDims(); SmallVector extractPos(extractOp.getPosition()); - for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i) + int64_t broadcastRankDiff = broadcastDstRank - broadcastSrcRank; + for (int64_t i = broadcastRankDiff, e = extractPos.size(); i < e; ++i) if (broadcastedUnitDims.contains(i)) extractPos[i] = 0; // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the // matching extract position when extracting from the source operand. + int64_t rankDiff = broadcastSrcRank - extractResultRank; extractPos.erase(extractPos.begin(), std::next(extractPos.begin(), extractPos.size() - rankDiff)); // OpBuilder is only used as a helper to build an I64ArrayAttr. @@ -4953,7 +4956,8 @@ OpFoldResult vector::TransposeOp::fold(FoldAdaptor adaptor) { // Eliminate splat constant transpose ops. - if (auto attr = llvm::dyn_cast_if_present(adaptor.getVector())) + if (auto attr = + llvm::dyn_cast_if_present(adaptor.getVector())) if (attr.isSplat()) return attr.reshape(getResultVectorType()); diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -2104,6 +2104,15 @@ return %1: vector<1xf32> } +// CHECK-LABEL: func.func @extract_from_stretch_broadcast +func.func @extract_from_stretch_broadcast(%src: vector<3x1x2xf32>) -> f32 { + // CHECK-NEXT: %0 = vector.extract {{.*}}[0, 0, 0] : vector<3x1x2xf32> + // CHECK-NEXT: return %0 : f32 + %0 = vector.broadcast %src : vector<3x1x2xf32> to vector<3x4x2xf32> + %1 = vector.extract %0[0, 2, 0] : vector<3x4x2xf32> + return %1: f32 +} + // ----- // CHECK-LABEL: func.func @extract_strided_slice_of_constant_mask func.func @extract_strided_slice_of_constant_mask() -> vector<5x7xi1>{