[mlir][vector] Handle 0-D vector sources when rewriting extract(broadcast)

When a `vector.extract` yields a rank-0 (scalar) result from a
`vector.broadcast` whose source is a 0-D vector, rewrite the extract to
`vector.extractelement` on the 0-D source instead of falling through to the
generic broadcast rewrite. Adds a canonicalization test for this case.

diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1632,6 +1632,13 @@
     // folding patterns.
     if (extractResultRank < broadcastSrcRank)
       return failure();
+
+    // Special case if broadcast src is a 0D vector.
+    if (extractResultRank == 0) {
+      assert(broadcastSrcRank == 0 && source.getType().isa<VectorType>());
+      rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(extractOp, source);
+      return success();
+    }
     rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
         extractOp, extractOp.getType(), source);
     return success();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -519,6 +519,18 @@
 
 // -----
 
+// CHECK-LABEL: fold_extract_broadcast_0dvec
+// CHECK-SAME: %[[A:.*]]: vector<f32>
+// CHECK: %[[B:.+]] = vector.extractelement %[[A]][] : vector<f32>
+// CHECK: return %[[B]] : f32
+func.func @fold_extract_broadcast_0dvec(%a : vector<f32>) -> f32 {
+  %b = vector.broadcast %a : vector<f32> to vector<1x2x4xf32>
+  %r = vector.extract %b[0, 1, 2] : vector<1x2x4xf32>
+  return %r : f32
+}
+
+// -----
+
 // CHECK-LABEL: fold_extract_broadcast_negative
 // CHECK: vector.broadcast %{{.*}} : vector<1x1xf32> to vector<1x1x4xf32>
 // CHECK: vector.extract %{{.*}}[0, 0] : vector<1x1x4xf32>