diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -2501,24 +2501,27 @@
     if (!broadcast)
       return failure();
     auto srcVecType = broadcast.getSource().getType().dyn_cast<VectorType>();
-    unsigned srcRrank = srcVecType ? srcVecType.getRank() : 0;
+    unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
     auto dstVecType = op.getType().cast<VectorType>();
     unsigned dstRank = dstVecType.getRank();
-    unsigned rankDiff = dstRank - srcRrank;
+    unsigned rankDiff = dstRank - srcRank;
     // Check if the most inner dimensions of the source of the broadcast are the
     // same as the destination of the extract. If this is the case we can just
     // use a broadcast as the original dimensions are untouched.
     bool lowerDimMatch = true;
-    for (unsigned i = 0; i < srcRrank; i++) {
+    for (unsigned i = 0; i < srcRank; i++) {
       if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
         lowerDimMatch = false;
         break;
       }
     }
     Value source = broadcast.getSource();
-    if (!lowerDimMatch) {
-      // The inner dimensions don't match, it means we need to extract from the
-      // source of the orignal broadcast and then broadcast the extracted value.
+    // If the inner dimensions don't match, we need to extract from the
+    // source of the original broadcast and then broadcast the extracted
+    // value. We also need to handle degenerate cases where the source is
+    // effectively just a single scalar.
+    bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
+    if (!lowerDimMatch && !isScalarSrc) {
       source = rewriter.create<ExtractStridedSliceOp>(
           op->getLoc(), source,
           getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -762,6 +762,34 @@
 
 // -----
 
+// CHECK-LABEL: func @extract_strided_broadcast3
+// CHECK-SAME: (%[[ARG:.+]]: vector<1xf32>)
+// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : vector<1xf32> to vector<1x4xf32>
+// CHECK: return %[[V]]
+func @extract_strided_broadcast3(%arg0: vector<1xf32>) -> vector<1x4xf32> {
+  %0 = vector.broadcast %arg0 : vector<1xf32> to vector<1x8xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [0, 4], sizes = [1, 4], strides = [1, 1]}
+    : vector<1x8xf32> to vector<1x4xf32>
+  return %1 : vector<1x4xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_strided_broadcast4
+// CHECK-SAME: (%[[ARG:.+]]: f32)
+// CHECK: %[[V:.+]] = vector.broadcast %[[ARG]] : f32 to vector<1x4xf32>
+// CHECK: return %[[V]]
+func @extract_strided_broadcast4(%arg0: f32) -> vector<1x4xf32> {
+  %0 = vector.broadcast %arg0 : f32 to vector<1x8xf32>
+  %1 = vector.extract_strided_slice %0
+    {offsets = [0, 4], sizes = [1, 4], strides = [1, 1]}
+    : vector<1x8xf32> to vector<1x4xf32>
+  return %1 : vector<1x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: consecutive_shape_cast
 // CHECK: %[[C:.*]] = vector.shape_cast %{{.*}} : vector<16xf16> to vector<4x4xf16>
 // CHECK-NEXT: return %[[C]] : vector<4x4xf16>
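
For reference, the `!lowerDimMatch` path that the new `isScalarSrc` check now bypasses is the one that slices the broadcast source and re-broadcasts it. A rough sketch of that existing fold on IR, not part of this patch (the function and value names below are illustrative only):

// Input: the sliced inner size (4) differs from the source's inner dimension
// (8), and the source is not a single scalar, so the pattern extracts a slice
// of %arg0 first and then broadcasts the narrower vector.
func @extract_strided_broadcast_sketch(%arg0: vector<8xf32>) -> vector<2x4xf32> {
  %0 = vector.broadcast %arg0 : vector<8xf32> to vector<2x8xf32>
  %1 = vector.extract_strided_slice %0
    {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]}
    : vector<2x8xf32> to vector<2x4xf32>
  return %1 : vector<2x4xf32>
}

// Expected canonicalized form (offsets/sizes/strides with the leading
// rankDiff entries dropped, as in getI64SubArray above):
//   %s = vector.extract_strided_slice %arg0
//     {offsets = [4], sizes = [4], strides = [1]}
//     : vector<8xf32> to vector<4xf32>
//   %b = vector.broadcast %s : vector<4xf32> to vector<2x4xf32>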