diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -9,6 +9,7 @@
 #ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
 #define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMS_H
 
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/PatternMatch.h"
 
 namespace mlir {
@@ -20,6 +21,20 @@
 void populateSplitPaddingPatterns(RewritePatternSet &patterns,
                                   PatternBenefit baseBenefit = 1);
 
+/// Function to control the folding of constant and extract slice. The
+/// extract_slice op is folded only when this function returns true.
+using ControlConstantExtractSliceFusionFn = std::function<bool(ExtractSliceOp)>;
+
+/// Patterns to fold the extract slice op with its constant operand.
+void populateFoldConstantExtractSlicePatterns(
+    RewritePatternSet &patterns,
+    const ControlConstantExtractSliceFusionFn &controlFn =
+        [](ExtractSliceOp op) {
+          // It is disabled by default because the folding can generate a large
+          // constant tensor, which would affect the compile time and storage.
+          return false;
+        });
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -10,12 +10,16 @@
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
+#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
+#include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/STLExtras.h"
@@ -1158,8 +1162,145 @@
     return success();
   }
 };
+
+/// Appends to `outValues` the elements of a row-major buffer, starting at
+/// `values`, that are selected by the static slice described by `offsets`,
+/// `sizes` and `strides`. `counts[i]` is the number of elements spanned by a
+/// unit step along dimension `i` of the source buffer.
+template <typename IterTy, typename ElemTy>
+static void sliceElements(IterTy values, ArrayRef<int64_t> counts,
+                          ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
+                          ArrayRef<int64_t> strides,
+                          llvm::SmallVectorImpl<ElemTy> *outValues) {
+  assert(offsets.size() == sizes.size());
+  assert(offsets.size() == strides.size());
+  assert(offsets.size() == counts.size());
+  // A rank-0 slice selects exactly the single element `values` points to.
+  if (offsets.empty()) {
+    outValues->push_back(*values);
+    return;
+  }
+
+  int64_t offset = offsets.front();
+  int64_t size = sizes.front();
+  int64_t stride = strides.front();
+  if (offsets.size() == 1) {
+    // Innermost dimension: copy the selected elements directly.
+    for (int64_t i = 0; i < size; ++i, offset += stride)
+      outValues->push_back(*(values + offset));
+
+    return;
+  }
+
+  // Outer dimension: recurse into each selected sub-buffer.
+  for (int64_t i = 0; i < size; ++i, offset += stride) {
+    auto begin = values + offset * counts.front();
+    sliceElements<IterTy, ElemTy>(begin, counts.drop_front(),
+                                  offsets.drop_front(), sizes.drop_front(),
+                                  strides.drop_front(), outValues);
+  }
+}
+
+/// Folds a tensor.extract_slice whose source is a non-splat constant into a
+/// new arith.constant holding only the sliced elements. Only static shapes,
+/// offsets, sizes and strides are supported, and `controlFn` decides whether a
+/// given op is folded.
+class ConstantOpExtractSliceFolder final
+    : public OpRewritePattern<ExtractSliceOp> {
+public:
+  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;
+
+  ConstantOpExtractSliceFolder(MLIRContext *context,
+                               ControlConstantExtractSliceFusionFn controlFn)
+      : OpRewritePattern<ExtractSliceOp>(context),
+        controlFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(ExtractSliceOp op,
+                                PatternRewriter &rewriter) const override {
+    DenseElementsAttr attr;
+    if (!matchPattern(op.source(), m_Constant(&attr)))
+      return failure();
+
+    // A constant splat is handled by fold().
+    if (attr.isSplat())
+      return failure();
+
+    // Dynamic result shape is not supported.
+    auto sourceType = op.source().getType().cast<ShapedType>();
+    auto resultType = op.result().getType().cast<ShapedType>();
+    if (!sourceType.hasStaticShape() || !resultType.hasStaticShape())
+      return failure();
+
+    // Customized control over the folding.
+    if (!controlFn(op))
+      return failure();
+
+    // Nothing to fold for an empty source; this also guarantees that every
+    // dimension is non-zero, so the divisions below are safe.
+    int64_t count = sourceType.getNumElements();
+    if (count == 0)
+      return failure();
+
+    // Check if there are any dynamic parts, which are not supported.
+    auto offsets = extractFromI64ArrayAttr(op.static_offsets());
+    if (llvm::is_contained(offsets, ShapedType::kDynamicStrideOrOffset))
+      return failure();
+    auto sizes = extractFromI64ArrayAttr(op.static_sizes());
+    if (llvm::is_contained(sizes, ShapedType::kDynamicSize))
+      return failure();
+    auto strides = extractFromI64ArrayAttr(op.static_strides());
+    if (llvm::is_contained(strides, ShapedType::kDynamicStrideOrOffset))
+      return failure();
+
+    // Compute the row-major stride (in elements) of each source dimension.
+    SmallVector<int64_t> counts;
+    ArrayRef<int64_t> shape = sourceType.getShape();
+    counts.reserve(shape.size());
+    for (int64_t v : shape) {
+      count = count / v;
+      counts.push_back(count);
+    }
+
+    // New attribute holding only the sliced values. Reserve for the result
+    // size, not the (possibly much larger) source size.
+    DenseElementsAttr newAttr;
+    int64_t numResultElements = resultType.getNumElements();
+
+    if (auto elems = attr.dyn_cast<DenseIntElementsAttr>()) {
+      SmallVector<APInt> outValues;
+      outValues.reserve(numResultElements);
+      sliceElements<DenseElementsAttr::IntElementIterator, APInt>(
+          elems.begin(), counts, offsets, sizes, strides, &outValues);
+      newAttr = DenseElementsAttr::get(resultType, outValues);
+    } else if (auto elems = attr.dyn_cast<DenseFPElementsAttr>()) {
+      SmallVector<APFloat> outValues;
+      outValues.reserve(numResultElements);
+      sliceElements<DenseElementsAttr::FloatElementIterator, APFloat>(
+          elems.begin(), counts, offsets, sizes, strides, &outValues);
+      newAttr = DenseElementsAttr::get(resultType, outValues);
+    }
+
+    // Other element types (e.g. complex) are left unfolded.
+    if (newAttr) {
+      rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, resultType, newAttr);
+      return success();
+    }
+
+    return failure();
+  }
+
+private:
+  ControlConstantExtractSliceFusionFn controlFn;
+};
+
 } // namespace
 
+void mlir::tensor::populateFoldConstantExtractSlicePatterns(
+    RewritePatternSet &patterns,
+    const ControlConstantExtractSliceFusionFn &controlFn) {
+  patterns.add<ConstantOpExtractSliceFolder>(patterns.getContext(), controlFn);
+}
+
 /// Return the canonical type of the result of an extract_slice op.
struct SliceReturnTypeCanonicalizer { RankedTensorType operator()(ExtractSliceOp op, @@ -1238,6 +1359,7 @@ return this->source(); if (Value slice = foldExtractAfterInsertSlice(*this)) return slice; + return OpFoldResult(); } diff --git a/mlir/test/Dialect/Tensor/fold-constant-extractslice.mlir b/mlir/test/Dialect/Tensor/fold-constant-extractslice.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/fold-constant-extractslice.mlir @@ -0,0 +1,39 @@ +// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-constant-extract-slice %s | FileCheck %s + +// CHECK-LABEL: func @slice_constant +// CHECK-NOT: tensor.extract_slice +// CHECK: %[[CONST:.+]] = arith.constant dense<1.000000e+01> : tensor<1x1xf32> +// CHECK: return %[[CONST]] : tensor<1x1xf32> +func @slice_constant(%arg0 : tensor<2x1xf32>) -> tensor<1x1xf32> +{ + %cst = arith.constant dense<[[10.0], [11.0]]> : tensor<2x1xf32> + %slice = tensor.extract_slice %cst[0, 0] [1, 1] [1, 1] : tensor<2x1xf32> to tensor<1x1xf32> + return %slice : tensor<1x1xf32> +} + +// ----- + +// CHECK-LABEL: func @slice_constant_3x4 +// CHECK-NOT: tensor.extract_slice +// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.000000e+01, 9.000000e+00], [1.100000e+01, 1.200000e+01]]> : tensor<2x2xf32> +// CHECK: return %[[CONST]] : tensor<2x2xf32> +func @slice_constant_3x4(%arg0 : tensor<3x4xf32>) -> tensor<2x2xf32> +{ + %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32> + %slice = tensor.extract_slice %cst[0, 0] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32> + return %slice : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: func @slice_constant_3x4_offsets +// CHECK-NOT: tensor.extract_slice +// CHECK: %[[CONST:.+]] = arith.constant dense<{{\[}}[1.200000e+01, 1.300000e+01], [3.000000e+00, 5.000000e+00]]> : tensor<2x2xf32> +// CHECK: return %[[CONST]] : tensor<2x2xf32> +func @slice_constant_3x4_offsets(%arg0 : tensor<3x4xf32>) -> 
tensor<2x2xf32> +{ + %cst = arith.constant dense<[[10.0, 9.0, 8.0, 7.0], [11.0, 12.0, 13.0, 14.0], [1.0, 3.0, 5.0, 7.0]]> : tensor<3x4xf32> + %slice = tensor.extract_slice %cst[1, 1] [2, 2] [1, 1] : tensor<3x4xf32> to tensor<2x2xf32> + return %slice : tensor<2x2xf32> +} + diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -41,6 +41,11 @@ *this, "test-split-padding-patterns", llvm::cl::desc("Test patterns to split tensor.pad ops"), llvm::cl::init(false)}; + + Option testFoldConstantExtractSlice{ + *this, "test-fold-constant-extract-slice", + llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"), + llvm::cl::init(false)}; }; } // namespace @@ -50,10 +55,31 @@ (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); } +static void applyFoldConstantExtractSlicePatterns(FuncOp funcOp) { + RewritePatternSet patterns(funcOp.getContext()); + tensor::ControlConstantExtractSliceFusionFn controlFn = + [](tensor::ExtractSliceOp op) { + if (!op.source().hasOneUse()) + return false; + + auto resultType = op.result().getType().cast(); + constexpr int64_t kConstantFoldingMaxNumElements = 1024; + if (resultType.getNumElements() > kConstantFoldingMaxNumElements) + return false; + + return true; + }; + + tensor::populateFoldConstantExtractSlicePatterns(patterns, controlFn); + (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); +} + void TestTensorTransforms::runOnOperation() { FuncOp func = getOperation(); if (testSplitPaddingPatterns) applySplitPaddingPatterns(func); + if (testFoldConstantExtractSlice) + applyFoldConstantExtractSlicePatterns(func); } namespace mlir {