diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1238,12 +1238,22 @@
   return {};
 }
 
-OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
+  // Slicing a splat constant yields a splat of the same value with the result
+  // type. A constant attribute needs a statically shaped type, so slices with
+  // a dynamic result shape fall through to the remaining folds.
+  if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
+    RankedTensorType resultType = getType();
+    if (resultType.hasStaticShape())
+      return SplatElementsAttr::get(resultType,
+                                    splat.getSplatValue<Attribute>());
+  }
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
   if (Value slice = foldExtractAfterInsertSlice(*this))
     return slice;
+
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -621,6 +621,18 @@
 
 // -----
 
+// CHECK-LABEL: func @fold_extract_constant_splat
+//       CHECK:   %[[CST:.+]] = arith.constant dense<42> : tensor<4x4xi32>
+//   CHECK-NOT:   tensor.extract_slice
+//       CHECK:   return %[[CST]] : tensor<4x4xi32>
+func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
+  %cst = arith.constant dense<42> : tensor<1024x1024xi32>
+  %slice = tensor.extract_slice %cst[0, 0] [4, 4] [1, 1] : tensor<1024x1024xi32> to tensor<4x4xi32>
+  return %slice : tensor<4x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 //  CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {