diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -15,7 +15,10 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/STLExtras.h"
@@ -1238,12 +1241,124 @@
   return {};
 }
 
+/// Copies the elements selected by a static slice of a constant into
+/// `outValues`. `counts` holds the linearized stride of each source dimension,
+/// i.e. the number of source elements spanned by a unit index step there.
+template <typename IterTy, typename ElemTy>
+static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
+                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+                          ArrayRef<int64_t> strides,
+                          llvm::SmallVectorImpl<ElemTy> *outValues) {
+  assert(offsets.size() == sizes.size());
+  assert(offsets.size() == strides.size());
+  if (offsets.empty())
+    return;
+
+  int64_t offset = offsets.front();
+  int64_t size = sizes.front();
+  int64_t stride = strides.front();
+  if (offsets.size() == 1) {
+    // Innermost dimension: `size` is the number of elements to extract,
+    // starting at `offset` and stepping by `stride`.
+    for (int64_t k = 0; k < size; ++k)
+      outValues->push_back(*(values + offset + k * stride));
+
+    return;
+  }
+
+  for (int64_t k = 0; k < size; ++k) {
+    auto begin = values + (offset + k * stride) * counts.front();
+    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
+                                  offsets.drop_front(), sizes.drop_front(),
+                                  strides.drop_front(), outValues);
+  }
+}
+
+/// Folds `op` into a new constant when it extracts a fully static slice out
+/// of a single-use dense constant; returns a null attribute otherwise.
+static Attribute foldConstant(ExtractSliceOp op) {
+  DenseElementsAttr attr;
+  if (!matchPattern(op.source(), m_Constant(&attr)))
+    return {};
+  // TODO: Support the splat case.
+  if (!attr || attr.isSplat())
+    return {};
+
+  // The case with multiple uses is not supported since it creates more
+  // constant data.
+  if (!op.source().getDefiningOp()->hasOneUse())
+    return {};
+
+  // Dynamic source or result shapes are not supported.
+  auto sourceType = op.source().getType().cast<ShapedType>();
+  if (!sourceType.hasStaticShape())
+    return {};
+
+  auto resultType = op.result().getType().cast<ShapedType>();
+  if (!resultType.hasStaticShape())
+    return {};
+
+  // Control the size. Since the way to get a new constant collects each
+  // element, it can have a bad impact on the compile time when the data
+  // size is big.
+  // TODO: create an option if a customization is needed.
+  constexpr int64_t kConstantFoldingMaxNumElements = 1024;
+  if (resultType.getNumElements() > kConstantFoldingMaxNumElements)
+    return {};
+
+  auto shape = sourceType.getShape();
+  int64_t count = sourceType.getNumElements();
+  if (count == 0)
+    return {};
+
+  // Check if there are any dynamic parts, which are not supported.
+  auto offsets = extractFromI64ArrayAttr(op.static_offsets());
+  if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
+    return {};
+  auto sizes = extractFromI64ArrayAttr(op.static_sizes());
+  if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
+    return {};
+  auto strides = extractFromI64ArrayAttr(op.static_strides());
+  if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
+    return {};
+
+  // Compute the linearized stride for each dimension.
+  SmallVector<int64_t> counts;
+  counts.reserve(shape.size());
+  for (int64_t v : shape) {
+    count = count / v;
+    counts.push_back(count);
+  }
+
+  if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
+    SmallVector<APInt> outValues;
+    outValues.reserve(resultType.getNumElements());
+    sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
+        elems.begin(), counts, offsets, sizes, strides, &outValues);
+    return DenseElementsAttr::get(resultType, outValues);
+  }
+
+  if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
+    SmallVector<APFloat> outValues;
+    outValues.reserve(resultType.getNumElements());
+    sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
+        elems.begin(), counts, offsets, sizes, strides, &outValues);
+    return DenseElementsAttr::get(resultType, outValues);
+  }
+
+  return {};
+}
+
 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
   if (Value slice = foldExtractAfterInsertSlice(*this))
     return slice;
+
+  if (auto slice = foldConstant(*this))
+    return slice;
+
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -387,6 +387,19 @@
 
 // -----
 
+// CHECK-LABEL: func @slice_constant
+// CHECK-NOT: tensor.extract_slice
+// CHECK: %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32>
+// CHECK: return %[[CONST]] : tensor<1x1xf32>
+func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32>
+{
+  %cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32>
+  %slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32>
+  return %slice : tensor<1x1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @trivial_insert_slice
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 // CHECK-NOT: tensor.extract_slice