diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -652,7 +652,8 @@
 
   /// Return a new DenseElementsAttr that has the same data as the current
   /// attribute, but has been reshaped to 'newType'. The new type must have the
-  /// same total number of elements as well as element type.
+  /// same element type. It must also have the same total number of elements,
+  /// unless this attribute is a splat.
   DenseElementsAttr reshape(ShapedType newType);
 
   /// Return a new DenseElementsAttr that has the same data as the current
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1238,7 +1238,14 @@
   return {};
 }
 
-OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
+  // Any slice of a splat constant is the same splat value with the slice's
+  // shape, so fold to a splat attribute when that shape is fully static.
+  if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
+    auto resultType = result().getType().cast<ShapedType>();
+    if (resultType.hasStaticShape())
+      return splat.reshape(resultType);
+  }
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -954,7 +954,8 @@
 
 /// Return a new DenseElementsAttr that has the same data as the current
 /// attribute, but has been reshaped to 'newType'. The new type must have the
-/// same total number of elements as well as element type.
+/// same element type. It must also have the same total number of elements,
+/// unless this attribute is a splat.
 DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
   ShapedType curType = getType();
   if (curType == newType)
@@ -962,8 +963,8 @@
 
   assert(newType.getElementType() == curType.getElementType() &&
          "expected the same element type");
-  assert(newType.getNumElements() == curType.getNumElements() &&
-         "expected the same number of elements");
+  assert((isSplat() || newType.getNumElements() == curType.getNumElements()) &&
+         "expected the same number of elements for a non-splat type");
 
   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
 }
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -621,6 +621,18 @@
 
 // -----
 
+// CHECK-LABEL: func @fold_extract_constant_splat
+// CHECK: %[[CST:.+]] = arith.constant dense<42> : tensor<4x4xi32>
+// CHECK-NOT: tensor.extract_slice
+// CHECK: return %[[CST]]
+func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
+  %cst = arith.constant dense<42> : tensor<1024x1024xi32>
+  %1 = tensor.extract_slice %cst[0, 0] [4, 4] [1, 1] : tensor<1024x1024xi32> to tensor<4x4xi32>
+  return %1 : tensor<4x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 // CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {