diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -713,6 +713,7 @@
     }
   }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def Vector_InsertOp :
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -1839,6 +1839,44 @@
   return success();
 }
 
+/// Folds `vector.insertelement` when the source scalar, the destination
+/// vector and the position are all constants, producing a new constant
+/// vector whose element at the given position is replaced by the source.
+///
+/// Returns a null result (no fold) when any operand is non-constant, the
+/// destination constant is not a DenseElementsAttr, the position is not an
+/// IntegerAttr, or the position is outside [0, number of elements).
+OpFoldResult vector::InsertElementOp::fold(ArrayRef<Attribute> operands) {
+  // Skip the 0-D vector here: it carries no position operand.
+  if (operands.size() < 3)
+    return {};
+
+  Attribute src = operands[0];
+  Attribute dst = operands[1];
+  Attribute pos = operands[2];
+  if (!src || !dst || !pos)
+    return {};
+
+  // Bail out on unexpected attribute kinds rather than asserting in cast<>
+  // or calling getInt() on a null IntegerAttr.
+  auto denseDst = dst.dyn_cast<DenseElementsAttr>();
+  auto posAttr = pos.dyn_cast<IntegerAttr>();
+  if (!denseDst || !posAttr)
+    return {};
+
+  // An out-of-bounds (or negative) constant index would write past the end
+  // of `results`, so leave such ops unfolded.
+  int64_t posIdx = posAttr.getInt();
+  if (posIdx < 0 || posIdx >= denseDst.getNumElements())
+    return {};
+
+  auto dstElements = denseDst.getValues<Attribute>();
+  SmallVector<Attribute> results(dstElements);
+  results[posIdx] = src;
+
+  return DenseElementsAttr::get(getDestVectorType(), results);
+}
+
 //===----------------------------------------------------------------------===//
 // InsertOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -1372,3 +1372,28 @@
   %t = vector.transpose %bcast, [1, 0] : vector<8x1xf32> to vector<1x8xf32>
   return %t : vector<1x8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @insert_element_fold
+// CHECK: %[[V:.+]] = arith.constant dense<[0, 1, 7, 3]> : vector<4xi32>
+// CHECK: return %[[V]]
+func @insert_element_fold() -> vector<4xi32> {
+  %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+  %s = arith.constant 7 : i32
+  %i = arith.constant 2 : i32
+  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+  return %1 : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @insert_element_fold_out_of_bounds
+// CHECK: vector.insertelement
+func @insert_element_fold_out_of_bounds() -> vector<4xi32> {
+  %v = arith.constant dense<[0, 1, 2, 3]> : vector<4xi32>
+  %s = arith.constant 7 : i32
+  %i = arith.constant 4 : i32
+  %1 = vector.insertelement %s, %v[%i : i32] : vector<4xi32>
+  return %1 : vector<4xi32>
+}