diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h
--- a/mlir/include/mlir/IR/BuiltinAttributes.h
+++ b/mlir/include/mlir/IR/BuiltinAttributes.h
@@ -655,6 +655,12 @@
   /// same total number of elements as well as element type.
   DenseElementsAttr reshape(ShapedType newType);
 
+  /// Return a new splat DenseElementsAttr of type 'newType' holding the same
+  /// splat value as the current attribute. The current attribute must be a
+  /// splat and 'newType' must have the same element type; unlike reshape, the
+  /// number of elements is allowed to differ.
+  DenseElementsAttr resizeSplat(ShapedType newType);
+
   /// Return a new DenseElementsAttr that has the same data as the current
   /// attribute, but has bitcast elements to 'newElType'. The new type must have
   /// the same bitwidth as the current element type.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1238,7 +1238,14 @@
   return {};
 }
 
-OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
+OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute> operands) {
+  // A slice of a splat constant is a splat of the same value, so fold it to a
+  // constant of the result type (dense attributes require a static shape).
+  if (auto splat = operands[0].dyn_cast_or_null<SplatElementsAttr>()) {
+    auto resultType = result().getType().cast<ShapedType>();
+    if (resultType.hasStaticShape())
+      return splat.resizeSplat(resultType);
+  }
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -967,6 +967,23 @@
   return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), isSplat());
 }
 
+/// Return a new splat attribute of type 'newType' with the same splat value.
+DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
+  assert(isSplat() && "expected a splat type");
+
+  ShapedType curType = getType();
+  if (curType == newType)
+    return *this;
+
+  assert(newType.getElementType() == curType.getElementType() &&
+         "expected the same element type");
+  // getRawData() is only meaningful for int/FP storage; string splats must be
+  // rebuilt from their string data instead.
+  if (isa<DenseStringElementsAttr>())
+    return DenseElementsAttr::get(newType, getRawStringData());
+  return DenseIntOrFPElementsAttr::getRaw(newType, getRawData(), true);
+}
+
 /// Return a new DenseElementsAttr that has the same data as the current
 /// attribute, but has bitcast elements such that it is now 'newType'. The new
 /// type must have the same shape and element types of the same bitwidth as the
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -621,6 +621,17 @@
 
 // -----
 
+// CHECK-LABEL: func @fold_extract_constant_splat
+// CHECK-NOT: tensor.extract_slice
+// CHECK: arith.constant dense<42> : tensor<4x4xi32>
+func @fold_extract_constant_splat() -> (tensor<4x4xi32>) {
+  %cst = arith.constant dense<42> : tensor<1024x1024xi32>
+  %1 = tensor.extract_slice %cst[0,0] [4,4] [1, 1] : tensor<1024x1024xi32> to tensor<4x4xi32>
+  return %1 : tensor<4x4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_overlapping_insert
 // CHECK-SAME: %[[INPUT:.+]]: tensor<?x?x?xf32>, %{{.+}}: tensor<4x?x8xf32>, %[[SLICE2:.+]]: tensor<4x?x8xf32>
 func @fold_overlapping_insert(%input : tensor<?x?x?xf32>, %slice1: tensor<4x?x8xf32>, %slice2: tensor<4x?x8xf32>, %i: index, %size: index) -> (tensor<?x?x?xf32>) {