diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1270,20 +1270,39 @@ } } -template -static Attribute foldExtractSliceAfterConstant(ExtractSliceOp op, - IterTy values) { +static Attribute foldConstant(ExtractSliceOp op) { + DenseElementsAttr attr; + if (!matchPattern(op.source(), m_Constant(&attr))) + return {}; + // TODO: Support the splat case. + if (!attr || attr.isSplat()) + return {}; + + // The case with multiple uses is not supported since it creates more + // constant data. + if (!op.source().getDefiningOp()->hasOneUse()) + return {}; + + // Dynamic source or result shapes are not supported. auto sourceType = op.source().getType().cast(); if (!sourceType.hasStaticShape()) return {}; + auto resultType = op.result().getType().dyn_cast(); + if (!resultType || !resultType.hasStaticShape()) + return {}; + + // Control the size. Since building the new constant collects each element, + // it can have a bad impact on the compile time when the data size is big. + // TODO: create an option if a customization is needed. + constexpr int64_t kConstantFoldingMaxNumElements = 1024; + if (resultType.getNumElements() > kConstantFoldingMaxNumElements) + return {}; + auto shape = sourceType.getShape(); int64_t count = sourceType.getNumElements(); - if (count == 0) { - return DenseElementsAttr::get( - op.result().getType().cast(), - /*list=*/{}); - } + if (count == 0) + return {}; // Check if there are any dynamic parts, which are not supported. 
auto offsets = extractFromI64ArrayAttr(op.static_offsets()); @@ -1304,52 +1323,22 @@ counts.push_back(count); } - SmallVector outValues; outValues.reserve(sourceType.getNumElements()); sliceElements(values, counts, offsets, sizes, strides, &outValues); return DenseElementsAttr::get(op.result().getType().cast(), outValues); } -static Attribute foldConstant(ExtractSliceOp op) { - DenseElementsAttr attr; - if (!matchPattern(op.source(), m_Constant(&attr))) - return {}; - // TODO: Support the splat case. - if (!attr || attr.isSplat()) - return {}; - - // The case with multiple uses is not supported since it creates more - // constant data. - if (!op.source().getDefiningOp()->hasOneUse()) - return {}; - - // Dynamic result shape is not supported. - auto resultType = op.result().getType().dyn_cast(); - if (!resultType || !resultType.hasStaticShape()) - return {}; - - // Control the size. Sice the way to get a new constant collects each element, - // it can have a bad impact on the compile time when the data size is big. - // TODO: create an option if a customization is needed. 
- constexpr int64_t kConstantFoldingMaxNumElements = 1024; - if (resultType.getNumElements() > kConstantFoldingMaxNumElements) - return {}; + if (auto elems = attr.dyn_cast()) { + SmallVector outValues; + outValues.reserve(resultType.getNumElements()); + sliceElements( + elems.begin(), counts, offsets, sizes, strides, &outValues); + return DenseElementsAttr::get(op.result().getType().cast(), + outValues); + } - if (auto intElems = attr.dyn_cast()) { - if (auto folded = - foldExtractSliceAfterConstant(op, intElems.begin())) { - return folded; - } - } else if (auto floatElems = attr.dyn_cast()) { - if (auto folded = foldExtractSliceAfterConstant< - DenseElementsAttr::FloatElementIterator, APFloat>( - op, floatElems.begin())) { - return folded; - } + if (auto elems = attr.dyn_cast()) { + SmallVector outValues; + outValues.reserve(resultType.getNumElements()); + sliceElements( + elems.begin(), counts, offsets, sizes, strides, &outValues); + return DenseElementsAttr::get(op.result().getType().cast(), + outValues); } return {};