diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -713,6 +713,7 @@
     }
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def Vector_InsertOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1820,6 +1820,39 @@
   return success();
 }
 
+/// Fold `vector.insertelement` when the inserted value, the destination vector
+/// and the position are all constants, producing the destination constant
+/// with the element at `position` replaced by the inserted value.
+OpFoldResult vector::InsertElementOp::fold(ArrayRef<Attribute> operands) {
+  // 0-D vectors take no position operand; never read operands[2] then.
+  if (operands.size() < 3)
+    return {};
+
+  Attribute src = operands[0];
+  Attribute dst = operands[1];
+  Attribute pos = operands[2];
+  if (!src || !dst || !pos)
+    return {};
+
+  // Only handle dense destination constants and integer positions; anything
+  // else (e.g. sparse or opaque elements) is left unfolded instead of crashing.
+  auto dstAttr = dst.dyn_cast<DenseElementsAttr>();
+  auto posAttr = pos.dyn_cast<IntegerAttr>();
+  if (!dstAttr || !posAttr)
+    return {};
+
+  // Leave negative or out-of-range positions unfolded rather than writing
+  // past the end of `results`.
+  int64_t posIdx = posAttr.getInt();
+  if (posIdx < 0 || posIdx >= dstAttr.getNumElements())
+    return {};
+
+  auto dstElements = dstAttr.getValues<Attribute>();
+  SmallVector<Attribute> results(dstElements.begin(), dstElements.end());
+  results[posIdx] = src;
+  return DenseElementsAttr::get(getDestVectorType(), results);
+}
+
 //===----------------------------------------------------------------------===//
 // InsertOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1276,3 +1276,29 @@
   %shuffle = vector.shuffle %v0, %v1 [3, 2, 5, 1] : vector<3xi32>, vector<3xi32>
   return %shuffle : vector<4xi32>
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_element_fold
+// CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32>
+// CHECK: return %[[V]]
+func @insert_element_fold() -> vector<4xi32> {
+  %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+  %s = arith.constant 7 : i32
+  %i = arith.constant 2 : i32
+  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+  return %1 : vector<4xi32>
+}
+
+// -----
+
+// Out-of-range positions must not be folded.
+// CHECK-LABEL: func @insert_element_invalid_fold
+// CHECK: vector.insertelement
+func @insert_element_invalid_fold() -> vector<4xi32> {
+  %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+  %s = arith.constant 7 : i32
+  %i = arith.constant 4 : i32
+  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+  return %1 : vector<4xi32>
+}