diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1046,6 +1046,12 @@
   if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
     return splat.getInput();
 
+  // Fold extractelement(broadcast(X)) -> X. Only valid when X is a scalar:
+  // a vector source has a different type than the extracted element.
+  if (auto broadcast = getVector().getDefiningOp<vector::BroadcastOp>())
+    if (!broadcast.getSource().getType().isa<VectorType>())
+      return broadcast.getSource();
+
   if (!pos || !src)
     return {};
 
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2095,3 +2095,15 @@
   %res = vector.extract_strided_slice %mask {offsets = [3], sizes = [5], strides = [1]} : vector<12x7xi1> to vector<5x7xi1>
   return %res : vector<5x7xi1>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @fold_extractelement_of_broadcast(
+//  CHECK-SAME:     %[[f:.*]]: f32
+//       CHECK:   return %[[f]]
+func.func @fold_extractelement_of_broadcast(%f: f32) -> f32 {
+  %0 = vector.broadcast %f : f32 to vector<15xf32>
+  %c5 = arith.constant 5 : index
+  %1 = vector.extractelement %0 [%c5 : index] : vector<15xf32>
+  return %1 : f32
+}