diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -15,7 +15,10 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/STLExtras.h"
@@ -1238,12 +1241,119 @@
   return {};
 }
 
+template <typename IterTy, typename ElemTy>
+static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
+                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+                          ArrayRef<int64_t> strides,
+                          llvm::SmallVectorImpl<ElemTy> *outValues) {
+  assert(offsets.size() == sizes.size());
+  assert(offsets.size() == strides.size());
+  if (offsets.empty())
+    return;
+
+  int64_t offset = offsets.front();
+  int64_t size = sizes.front();
+  int64_t stride = strides.front();
+  if (offsets.size() == 1) {
+    for (int64_t i = 0; i < size; ++i, offset += stride)
+      outValues->push_back(*(values + offset));
+
+    return;
+  }
+
+  for (int64_t i = 0; i < size; ++i, offset += stride) {
+    auto begin = values + offset * counts.front();
+    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
+                                  offsets.drop_front(), sizes.drop_front(),
+                                  strides.drop_front(), outValues);
+  }
+}
+
+static Attribute foldConstant(ExtractSliceOp op) {
+  DenseElementsAttr attr;
+  if (!matchPattern(op.source(), m_Constant(&attr)))
+    return {};
+  // TODO: Support the splat case.
+  if (!attr || attr.isSplat())
+    return {};
+
+  // The case with multiple uses is not supported since it creates more
+  // constant data.
+  if (!op.source().hasOneUse())
+    return {};
+
+  // Dynamic result shape is not supported.
+  auto sourceType = op.source().getType().cast<ShapedType>();
+  if (!sourceType.hasStaticShape())
+    return {};
+
+  auto resultType = op.result().getType().cast<ShapedType>();
+  if (!resultType.hasStaticShape())
+    return {};
+
+  // Control the size. Since the way to get a new constant collects each
+  // element, it can have a negative impact on the compile time when the data
+  // size is big.
+  // TODO: If the size of the constant folded is to be controlled, move this
+  // out of folding and make it a separate pattern which can accept an option
+  // to control the size.
+  constexpr int64_t kConstantFoldingMaxNumElements = 1024;
+  if (resultType.getNumElements() > kConstantFoldingMaxNumElements)
+    return {};
+
+  ArrayRef<int64_t> shape = sourceType.getShape();
+  int64_t count = sourceType.getNumElements();
+  if (count == 0)
+    return {};
+
+  // Check if there are any dynamic parts, which are not supported.
+  auto offsets = extractFromI64ArrayAttr(op.static_offsets());
+  if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
+    return {};
+  auto sizes = extractFromI64ArrayAttr(op.static_sizes());
+  if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
+    return {};
+  auto strides = extractFromI64ArrayAttr(op.static_strides());
+  if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
+    return {};
+
+  // Compute the stride into the flat element buffer for each dimension.
+  SmallVector<int64_t> counts;
+  counts.reserve(shape.size());
+  for (int64_t v : shape) {
+    count = count / v;
+    counts.push_back(count);
+  }
+
+  if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
+    SmallVector<APInt> outValues;
+    outValues.reserve(sourceType.getNumElements());
+    sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
+        elems.begin(), counts, offsets, sizes, strides, &outValues);
+    return DenseElementsAttr::get(resultType, outValues);
+  }
+
+  if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
+    SmallVector<APFloat> outValues;
+    outValues.reserve(sourceType.getNumElements());
+    sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
+        elems.begin(), counts, offsets, sizes, strides, &outValues);
+    return DenseElementsAttr::get(resultType, outValues);
+  }
+
+  return {};
+}
+
 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
   if (Value slice = foldExtractAfterInsertSlice(*this))
     return slice;
+
+  if (Attribute slice = foldConstant(*this))
+    return slice;
+
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -387,6 +387,43 @@
 
 // -----
 
+// CHECK-LABEL: func @slice_constant
+// CHECK-NOT: tensor.extract_slice
+// CHECK: %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32>
+// CHECK: return %[[CONST]] : tensor<1x1xf32>
+func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32>
+{
+  %cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32>
+  %slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32>
+  return %slice : tensor<1x1xf32>
+}
+
+// CHECK-LABEL: func @slice_constant_3x4
+// CHECK-NOT: tensor.extract_slice
+// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.000000e+01, 9.000000e+00], [1.100000e+01, 1.200000e+01]]> : tensor<2x2xf32>
+// CHECK: return %[[CONST]] : tensor<2x2xf32>
+func @slice_constant_3x4(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
+{
+  %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
+  %slice = tensor.extract_slice %cst[0, 0] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
+  return %slice : tensor<2x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @slice_constant_3x4_offsets
+// CHECK-NOT: tensor.extract_slice
+// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.200000e+01, 1.300000e+01], [3.000000e+00, 5.000000e+00]]> : tensor<2x2xf32>
+// CHECK: return %[[CONST]] : tensor<2x2xf32>
+func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32>
+{
+  %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32>
+  %slice = tensor.extract_slice %cst[1, 1] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32>
+  return %slice : tensor<2x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @trivial_insert_slice
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 // CHECK-NOT: tensor.extract_slice
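
Note on the folding logic (not part of the patch): sliceElements walks the constant's row-major element buffer recursively, and counts[d] holds the number of buffer elements covered by one step along dimension d, so offset * counts[d] linearizes a multi-dimensional offset. The standalone C++ sketch below, independent of MLIR and using hypothetical names such as sliceBuffer, mirrors that walk on the 3x4 data from the @slice_constant_3x4_offsets test and prints 12 13 3 5, matching the expected folded constant.

#include <cstdint>
#include <iostream>
#include <vector>

// Recursively copy the requested slice out of a flat row-major buffer.
// counts[d] is the number of buffer elements covered by one step along
// dimension d (the product of the trailing dimension sizes), mirroring the
// `counts` vector built in foldConstant above.
static void sliceBuffer(const double *values,
                        const std::vector<int64_t> &counts,
                        const std::vector<int64_t> &offsets,
                        const std::vector<int64_t> &sizes,
                        const std::vector<int64_t> &strides, size_t dim,
                        std::vector<double> &out) {
  int64_t offset = offsets[dim];
  if (dim + 1 == offsets.size()) {
    // Innermost dimension: elements are contiguous, so index directly.
    for (int64_t i = 0; i < sizes[dim]; ++i, offset += strides[dim])
      out.push_back(values[offset]);
    return;
  }
  // Outer dimensions: scale the offset by counts[dim] to linearize it, then
  // recurse into the next dimension.
  for (int64_t i = 0; i < sizes[dim]; ++i, offset += strides[dim])
    sliceBuffer(values + offset * counts[dim], counts, offsets, sizes, strides,
                dim + 1, out);
}

int main() {
  // Row-major data of the 3x4 constant used in @slice_constant_3x4_offsets.
  std::vector<double> data = {10, 9, 8, 7, 11, 12, 13, 14, 1, 3, 5, 7};
  // counts for shape [3, 4] with 12 elements: 12/3 = 4, then 4/4 = 1.
  std::vector<int64_t> counts = {4, 1};
  // extract_slice %cst[1, 1] [2, 2] [1, 1]
  std::vector<int64_t> offsets = {1, 1}, sizes = {2, 2}, strides = {1, 1};
  std::vector<double> out;
  sliceBuffer(data.data(), counts, offsets, sizes, strides, 0, out);
  for (double v : out)
    std::cout << v << ' '; // prints: 12 13 3 5
  std::cout << '\n';
}

The innermost dimension is handled as a direct loop because elements along it sit next to each other in the buffer (up to the slice stride), which is also why the patch only multiplies by counts.front() in the recursive case.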