diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -530,6 +530,7 @@
     }
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def Vector_ExtractOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -943,6 +943,32 @@
   return success();
 }
 
+/// Fold `vector.extractelement` of a constant vector at a constant position
+/// into the selected element. Bails out (returns a null OpFoldResult) unless
+/// the source is a DenseElementsAttr, the position is an IntegerAttr, and the
+/// position lies inside the vector; folding an out-of-range position would
+/// read past the end of the constant's elements.
+OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
+  // The 0-D form has no position operand; skip it for now.
+  if (operands.size() < 2)
+    return {};
+
+  // Operands that are not compile-time constants are null; constants of an
+  // unexpected attribute kind are rejected instead of crashing a hard cast.
+  auto src = operands[0].dyn_cast_or_null<DenseElementsAttr>();
+  auto pos = operands[1].dyn_cast_or_null<IntegerAttr>();
+  if (!src || !pos)
+    return {};
+
+  // Never fold an out-of-bounds (or negative) constant index.
+  int64_t posIdx = pos.getInt();
+  if (posIdx < 0 || posIdx >= src.getNumElements())
+    return {};
+
+  auto srcElements = src.getValues<Attribute>();
+  return srcElements[posIdx];
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1385,3 +1385,26 @@
   %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
   return %1 : vector<4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @extract_element_fold
+// CHECK: %[[C:.+]] = arith.constant 5 : i32
+// CHECK: return %[[C]]
+func @extract_element_fold() -> i32 {
+  %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+  %i = arith.constant 2 : i32
+  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+  return %1 : i32
+}
+
+// -----
+
+// CHECK-LABEL: func @extract_element_fold_out_of_bounds
+// CHECK: vector.extractelement
+func @extract_element_fold_out_of_bounds() -> i32 {
+  %v = arith.constant dense<[1, 3, 5, 7]> : vector<4xi32>
+  %i = arith.constant 4 : i32
+  %1 = vector.extractelement %v[%i : i32] : vector<4xi32>
+  return %1 : i32
+}