diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -950,7 +950,13 @@
 
   Attribute src = operands[0];
   Attribute pos = operands[1];
+
+  // Fold extractelement(splat X) -> X. Every lane of a splat holds X, so this
+  // fold is valid for any position, whether or not the position is constant.
+  if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
+    return splat.getInput();
+
   if (!src || !pos)
     return {};
 
   auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>();
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1397,3 +1397,22 @@
   %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
   return %1 : i32
 }
+
+// CHECK-LABEL: func @extract_element_splat_fold
+// CHECK-SAME: (%[[ARG:.+]]: i32)
+// CHECK: return %[[ARG]]
+func @extract_element_splat_fold(%a : i32) -> i32 {
+  %v = vector.splat %a : vector<4xi32>
+  %i = arith.constant 2 : i32
+  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+  return %1 : i32
+}
+
+// CHECK-LABEL: func @extract_element_splat_fold_dynamic_position
+// CHECK-SAME: (%[[ARG:.+]]: i32, %{{.+}}: i32)
+// CHECK: return %[[ARG]]
+func @extract_element_splat_fold_dynamic_position(%a : i32, %i : i32) -> i32 {
+  %v = vector.splat %a : vector<4xi32>
+  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+  return %1 : i32
+}