diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -15,7 +15,10 @@
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/STLExtras.h"
@@ -1240,12 +1243,118 @@
   return {};
 }
 
+/// Slices elements from `values` into `outValues`. `counts` holds the
+/// linearized stride (in elements) of each dimension; `offsets`, `sizes`, and
+/// `strides` are the slice parameters along each dimension.
+template <typename IterTy, typename ElemTy>
+static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
+                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+                          ArrayRef<int64_t> strides,
+                          llvm::SmallVectorImpl<ElemTy> *outValues) {
+  assert(offsets.size() == sizes.size());
+  assert(offsets.size() == strides.size());
+  if (offsets.empty())
+    return;
+
+  int64_t offset = offsets.front();
+  int64_t size = sizes.front();
+  int64_t stride = strides.front();
+  if (offsets.size() == 1) {
+    for (int64_t i = offset, e = offset + size * stride; i < e; i += stride)
+      outValues->push_back(*(values + i));
+
+    return;
+  }
+
+  for (int64_t i = offset, e = offset + size * stride; i < e; i += stride) {
+    auto begin = values + i * counts.front();
+    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
+                                  offsets.drop_front(), sizes.drop_front(),
+                                  strides.drop_front(), outValues);
+  }
+}
+
+/// Folds an ExtractSliceOp whose source is a constant with a static shape by
+/// materializing the sliced elements as a new constant attribute.
+template <typename IterTy, typename ElemTy>
+static Attribute foldExtractSliceAfterConstant(ExtractSliceOp op,
+                                               IterTy values) {
+  auto sourceType = op.source().getType().cast<ShapedType>();
+  if (!sourceType.hasStaticShape())
+    return {};
+
+  auto shape = sourceType.getShape();
+  int64_t count = sourceType.getNumElements();
+  if (count == 0) {
+    return DenseElementsAttr::get(
+        op.result().getType().cast<ShapedType>(),
+        /*list=*/ArrayRef<ElemTy>());
+  }
+
+  // Check if there are any dynamic parts, which are not supported.
+  auto offsets = extractFromI64ArrayAttr(op.static_offsets());
+  if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
+    return {};
+  auto sizes = extractFromI64ArrayAttr(op.static_sizes());
+  if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
+    return {};
+  auto strides = extractFromI64ArrayAttr(op.static_strides());
+  if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
+    return {};
+
+  // Compute the linearized stride (in elements) for each dimension.
+  SmallVector<int64_t> counts;
+  counts.reserve(shape.size());
+  for (auto v : shape) {
+    count = count / v;
+    counts.push_back(count);
+  }
+
+  SmallVector<ElemTy> outValues;
+  outValues.reserve(sourceType.getNumElements());
+  sliceElements<IterTy, ElemTy>(values, counts, offsets, sizes, strides,
+                                &outValues);
+
+  return DenseElementsAttr::get(op.result().getType().cast<ShapedType>(),
+                                outValues);
+}
+
+/// Folds an ExtractSliceOp whose source is a non-splat constant.
+static Attribute foldConstant(ExtractSliceOp op) {
+  DenseElementsAttr attr;
+  if (!matchPattern(op.source(), m_Constant(&attr)))
+    return {};
+  // TODO: support the splat case.
+  if (!attr || attr.isSplat())
+    return {};
+
+  if (auto intElems = attr.dyn_cast<DenseIntElementsAttr>()) {
+    if (auto folded = foldExtractSliceAfterConstant<
+            DenseElementsAttr::IntElementIterator, APInt>(
+            op, intElems.begin())) {
+      return folded;
+    }
+  } else if (auto floatElems = attr.dyn_cast<DenseFPElementsAttr>()) {
+    if (auto folded = foldExtractSliceAfterConstant<
+            DenseElementsAttr::FloatElementIterator, APFloat>(
+            op, floatElems.begin())) {
+      return folded;
+    }
+  }
+
+  return {};
+}
+
 OpFoldResult ExtractSliceOp::fold(ArrayRef<Attribute>) {
   if (getSourceType() == getType() &&
       succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
     return this->source();
   if (Value slice = foldExtractAfterInsertSlice(*this))
     return slice;
+
+  if (auto slice = foldConstant(*this))
+    return slice;
+
   return OpFoldResult();
 }
 
diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir
--- a/mlir/test/Dialect/Tensor/canonicalize.mlir
+++ b/mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -387,6 +387,19 @@
 
 // -----
 
+// CHECK-LABEL: func @slice_constant
+// CHECK-NOT: tensor.extract_slice
+// CHECK: %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32>
+// CHECK: return %[[CONST]] : tensor<1x1xf32>
+func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32>
+{
+  %cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32>
+  %slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32>
+  return %slice : tensor<1x1xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @trivial_insert_slice
 // CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
 // CHECK-NOT: tensor.extract_slice
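
For illustration only (this example is not part of the diff above, and the
function name is hypothetical): the fold also covers integer element types via
DenseElementsAttr::IntElementIterator, and handles non-zero offsets and
strides, so a strided slice of an integer constant is expected to canonicalize
the same way:

  func @slice_constant_i32() -> tensor<2xi32> {
    %cst = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
    // Offsets [1], sizes [2], strides [2] select the elements at indices 1 and 3.
    %slice = tensor.extract_slice %cst[1] [2] [2] : tensor<4xi32> to tensor<2xi32>
    // Expected to fold to: arith.constant dense<[2, 4]> : tensor<2xi32>
    return %slice : tensor<2xi32>
  }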