diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -283,7 +283,11 @@ // Build an ExtractSliceOp with dynamic entries and inferred result type. OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, - CArg<"ArrayRef", "{}">:$attrs)> + CArg<"ArrayRef", "{}">:$attrs)>, + // Build an ExtractSliceOp with mixed static and dynamic entries packed in + // a Range vector. + OpBuilder<(ins "Value":$source, "ArrayRef":$ranges, + CArg<"ArrayRef", "{}">:$attrs)>, ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ @@ -739,6 +743,11 @@ // Build a InsertSliceOp with dynamic entries. OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, + CArg<"ArrayRef", "{}">:$attrs)>, + // Build an InsertSliceOp with mixed static and dynamic entries packed in + // a Range vector. + OpBuilder<(ins "Value":$source, "Value":$dest, + "ArrayRef":$ranges, CArg<"ArrayRef", "{}">:$attrs)> ]; @@ -1337,7 +1346,11 @@ "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, CArg<"ArrayRef", "{}">:$attrs)>, - + // Build a ParallelInsertSliceOp with mixed static and dynamic entries + // packed into a Range vector. + OpBuilder<(ins "Value":$source, "Value":$dest, + "ArrayRef":$ranges, + CArg<"ArrayRef", "{}">:$attrs)>, // Build a ParallelInsertSliceOp with dynamic entries. 
OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/TransformUtils.h @@ -0,0 +1,210 @@ +//===- TransformUtils.h - Tensor Transformation Utilities--------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H +#define MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H + +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/PatternMatch.h" + +namespace mlir { +namespace tensor { + +//===----------------------------------------------------------------------===// +// Extract slice from `tensor.collapse_shape` +//===----------------------------------------------------------------------===// + +/// This class assists with generating IR required to materialize an +/// arbitrary-sized slice from the result of a CollapseShapeOp. In order to +/// accomplish this, a loop nest or similar operation must be created by the +/// caller. The purpose of the loop nest is to generate a "tiling by 1" of all +/// sliced dimensions. The "tiling by 1" assembles all elements of the result +/// tile over dimensions that would have been impossible to directly slice. +/// +/// The class provides three methods: +/// 1. `ExtractSliceFromCollapseHelper::create`: emits IR that should +/// appear before the loop nest and populates the internal state. +/// 2. `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`: returns +/// parameters used by the caller to construct the loop nest. +/// 3. 
`ExtractSliceFromCollapseHelper::emitLoopNestBody`: +/// emits IR to construct a "size-1 tile" of the desired result and returns a +/// set of ranges where the tile should be inserted into the destination +/// tensor. +/// +/// ### Intended usage: +/// +/// The caller should first call `ExtractSliceFromCollapseHelper::create` and +/// then create a destination tensor that is the same size as the desired slice. +/// The caller then creates a loop nest that iterates over the multi-dimensional +/// iteration space defined by `[0, ub[0]) x [0, ub[1]) x ... x [0, ub[N-1])` +/// where `ub` is the upper bound given by +/// `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`. Inside the body of +/// the loop nest, the caller should call +/// `ExtractSliceFromCollapseHelper::emitLoopNestBody` and provide the induction +/// variables. This returns a sub-tile and a set of ranges that describe where +/// this tile should be inserted into the result by the caller. For a complete +/// example of usage, see the examples in the TestTensorTransforms pass. +/// +/// ### Example: +/// Consider the following IR: +/// ``` +/// %0 = linalg.generic ... 
-> tensor<3x?x?x11x?xf32> +/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3, 4]] +/// : tensor<3x?x?x11x?xf32> into tensor<?x?xf32> +/// %2 = tensor.extract_slice %1 [%offt0, %offt1][%size0, %size1][1, 1] +/// : tensor<?x?xf32> to tensor<?x?xf32> +/// ``` +/// +/// We can construct %2 by generating the following, which only uses `%0`: +/// +/// ``` +/// %dest = linalg.init_tensor [%size0, %size1] : tensor<?x?xf32> +/// %1 = tensor.dim %0, %c1 : tensor<3x?x?x11x?xf32> +/// %2 = tensor.dim %0, %c2 : tensor<3x?x?x11x?xf32> +/// %3 = tensor.dim %0, %c4 : tensor<3x?x?x11x?xf32> +/// +/// %result = scf.for %iv0 = %c0 to %size0 step %c1 iter_args(%arg6 = %dest) -> +/// (tensor<?x?xf32>) { +/// %5 = scf.for %iv1 = %c0 to %size1 step %c1 iter_args(%arg8 = %arg6) +/// -> (tensor<?x?xf32>) { +/// %lin0 = (affine.apply) %iv0 + %offt0 +/// %lin1 = (affine.apply) %iv1 + %offt1 +/// +/// %mi0:3 = affine.delinearize_index %lin0 into (%c3, %1, %2) +/// %mi1:2 = affine.delinearize_index %lin1 into (%c11, %3) +/// +/// %sub_tile = tensor.extract_slice %0 +/// [%mi0#0, %mi0#1, %mi0#2, %mi1#0, %mi1#1] +/// [1, 1, 1, 1, 1] +/// [1, 1, 1, 1, 1] +/// : tensor<3x?x?x11x?xf32> to tensor<1x1x1x1x1xf32> +/// %sub_tile_collapsed = tensor.collapse_shape %sub_tile +/// [[0, 1, 2], [3, 4]] +/// : tensor<1x1x1x1x1xf32> into tensor<1x1xf32> +/// +/// %12 = tensor.insert_slice %sub_tile_collapsed into +/// %arg8[%iv0, %iv1] [1, 1] [1, 1] +/// : tensor<1x1xf32> into tensor<?x?xf32> +/// scf.yield %12 : tensor<?x?xf32> +/// } +/// scf.yield %5 : tensor<?x?xf32> +/// } +/// ``` +/// +/// ### Explanation of example: +/// +/// Each step above is explained below. +/// +/// #### Step 0: Create %dest and materialization of shapes. +/// This step is self-explanatory and performed by the caller. It can be +/// done before or after calling `ExtractSliceFromCollapseHelper::create`, +/// which materializes the source shape (`%1, %2, %3`). +/// +/// #### Step 1: Create loop nest. +/// +/// The caller creates the loop nest (depicted here is `scf.for`, but any other +/// similar op can be used). 
The iteration should start at zero and proceed with +/// step size 1 to the upper bounds given by +/// `ExtractSliceFromCollapseHelper::getIterationSpaceSizes`. This forms the +/// basis for the "tiling by 1". +/// +/// #### Step 2: Transform (%iv0, %iv1) from the index space of %2 to the index +/// space of %0. +/// +/// This step is performed by +/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`. +/// +/// The induction variables `%iv0` and `%iv1` live in the +/// index space of %2 (for dimensions 0 and 1, respectively). `%lin0` and +/// `%lin1` are the result of inverting (resolving) the index space +/// transformation represented by the slice operation, accounting for offset and +/// stride. Subsequently, `%mi0` and `%mi1` are the result of applying the +/// inverse index space transformation represented by `tensor.collapse_shape`. +/// This is accomplished using `affine.delinearize_index`. Note that %iv0 +/// and %iv1 now correspond to multi-indices `%mi0:3` and `%mi1:2`. +/// +/// #### Step 3: Extract a sub-tile slice from the source. +/// +/// This step is also performed by +/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`. +/// +/// The indices `%mi0` and `%mi1` are used to extract a slice from %0. This +/// slice is then collapsed down to match the result rank. +/// +/// #### Step 4: Insert sub-tile into the destination +/// +/// This step is performed by the caller using the results of +/// `ExtractSliceFromCollapseHelper::emitLoopNestBody`. +/// +/// In the above example, the slice insertion parameters are straightforward, +/// but in other possible situations, the slice parameters are more complicated, +/// which is why this helper calculates them for the caller. These other +/// situations correspond to: +/// 1. The presence of linearized dimensions that are not sliced +/// 2. The presence of non-linearized dimensions that are sliced. 
+class ExtractSliceFromCollapseHelper { +public: + /// Given a CollapseShapeOp and a set of ranges describing the desired slice + /// of its result, emits IR to materialize the shapes of the input and output + /// tensors, and returns an instance of the initialized class. Returns failure + /// if the slice is rank-reducing. + static FailureOr + create(OpBuilder &b, tensor::CollapseShapeOp op, ArrayRef sliceParams); + + /// Given a CollapseShapeOp and an ExtractSliceOp acting on its result, emits + /// IR to materialize the shapes of the input and output tensors of the + /// CollapseShapeOp, and returns an instance of the initialized class. Returns + /// failure if the slice is rank-reducing. + static FailureOr + create(OpBuilder &b, tensor::CollapseShapeOp collapseOp, + tensor::ExtractSliceOp extractOp); + + ExtractSliceFromCollapseHelper( + tensor::CollapseShapeOp collapseShapeOp, + ArrayRef collapseShapeInputShape, + ArrayRef collapseShapeOutputShape, + ArrayRef extractSliceParams, + const llvm::SmallBitVector &linearizedDimensions, + const llvm::SmallBitVector &slicedDimensions, ArrayRef tiledSizes) + : collapseShapeOp(collapseShapeOp), + collapseShapeInputShape(collapseShapeInputShape), + collapseShapeOutputShape(collapseShapeOutputShape), + sliceParams(extractSliceParams), + linearizedDimensions(linearizedDimensions), + slicedDimensions(slicedDimensions), tiledSizes(tiledSizes) {} + + /// Return the upper bounds of the iteration space (with 0 offset and stride + /// 1) required to create the desired slice. Note that this is not the same + /// as the `sizes` parameters of the ExtractSliceOp because not all dimensions + /// of the slice are required to be tiled to form the result. 
+ const SmallVector &getIterationSpaceSizes() { return tiledSizes; } + + /// Generates the IR inside of the caller's loop nest for 1) inverting the + /// index mappings of the ExtractSliceOp->CollapseShapeOp chain and 2) + /// extracting the CollapseShapeOp source tensor tile for this specified + /// iteration space point `tileInductionVars` and 3) calculating where to + /// insert the extracted tile. The returned pair consists of the results of + /// (2) and (3) and should be used by the caller to insert into the + /// destination tensor. + std::pair> + emitLoopNestBody(OpBuilder &builder, Location loc, + ValueRange tileInductionVars); + +private: + tensor::CollapseShapeOp collapseShapeOp; + SmallVector collapseShapeInputShape; + SmallVector collapseShapeOutputShape; + SmallVector sliceParams; + llvm::SmallBitVector linearizedDimensions; + llvm::SmallBitVector slicedDimensions; + SmallVector tiledSizes; +}; + +} // namespace tensor +} // namespace mlir + +#endif // MLIR_DIALECT_TENSOR_TRANSFORMS_TRANSFORMUTILS_H diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -14,6 +14,7 @@ #ifndef MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H #define MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" @@ -373,6 +374,90 @@ } }; +/// The input parameters `offsets`, `sizes`, `strides` specify a rectangular +/// non rank-reducing slice of the collapse_shape output. Try to find which +/// dimensions have been sliced and which dimensions are not sliced (offset = 0, +/// size = dim, size = 1). Note that this conservative as it cannot detect if a +/// dynamic size corresponds to the full tensor dimension or not. 
+llvm::SmallBitVector getSlicedDimensions(ArrayRef sliceInputShape, + ArrayRef sliceParams); + +/// Determine which dimensions are linearized by a `tensor.collapse_shape` op by +/// inspecting its reassociation indices. +llvm::SmallBitVector +getLinearizedDimensions(ArrayRef reassociationIndices); + +/// Given the parameters for both operations in a `CollapseShape->ExtractSlice` +/// chain and reified source and result shapes of the CollapseShapeOp, this +/// class provides two functions that assist with directly forming the result +/// of the extract slice by "tiling the CollapseShapeOp by 1". +//// Example: +// clang-format off +/// ``` +/// %0 = linalg.generic ... -> tensor<3x7x11x10xf32> +/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<341x10xf32> +/// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : .... tensor<10x10xf32> +/// ``` +/// This class helps build the below IR to replace %2: +/// ``` +/// %dest = linalg.init_tensor() : tensor<10x10xf32> +/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0) -> tensor<10x10xf32> { +/// %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 11)>(%iv) +/// %3:3 = arith.delinearize_index %iv into (3, 7, 11) +/// +/// // This function takes %3 (multiIndices) and the parameters for the slice below. 
+/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] : +/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32> +/// +/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] : +/// tensor<1x1x1x10xf32> into tensor<1x10xf32> +/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] : +/// tensor<1x10xf32> into tensor<10x10xf32> +/// scf.yield %6 : tensor<10x10xf32> +/// } +/// ``` +// clang-format on +class SliceFromCollapseHelper { +public: + SliceFromCollapseHelper(ArrayRef reassociationIndices, + ArrayRef collapseShapeInputShape, + ArrayRef collapseShapeOutputShape, + ArrayRef extractSliceParams) + : reassociationIndices(reassociationIndices), + collapseShapeInputShape(collapseShapeInputShape), + collapseShapeOutputShape(collapseShapeOutputShape), + sliceParams(extractSliceParams), + linearizedDimensions(getLinearizedDimensions(reassociationIndices)), + slicedDimensions(getSlicedDimensions(collapseShapeOutputShape, + extractSliceParams)) {} + + /// This function takes multi-indices and maps them to ExtractSlice parameters + /// in the index space of the CollapseShape's source tensor. This function's + /// signature can be described by `(D_0, D_1,.. D_{n-1}) -> (offsets, sizes, + /// strides)` where `n` the number of "tiled dimensions", which are the + /// dimensions of the output that are linearized by the collapse shape op and + /// are also sliced. Each `D_i` is a tuple that must represent a valid + /// multi-index for the `i-th` tiled dimension. In the example above, there is + /// only one tiled dimension (D_0) and `arith.delinearize_index` produces the + /// multi-index (%3) that would be passed to this function to generate the + /// parameters for the `tensor.extract_slice` op (%4). 
+ SmallVector getExtractSliceParams(ArrayRef multiIndices); + + /// This function takes indices in the index space of the "tiled dimensions" + /// described above and returns a set of Range variables that describe how the + /// slice should be inserted into the destination. In the example above, `%iv` + /// would be passed to this function to generate the parameters for the + /// `tensor.insert_slice` op producing %6. + SmallVector getInsertSliceParams(ValueRange tileIndices); + +private: + SmallVector reassociationIndices; + SmallVector collapseShapeInputShape; + SmallVector collapseShapeOutputShape; + SmallVector sliceParams; + llvm::SmallBitVector linearizedDimensions; + llvm::SmallBitVector slicedDimensions; +}; } // namespace mlir #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H diff --git a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h --- a/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StaticValueUtils.h @@ -29,6 +29,13 @@ OpFoldResult stride; }; +/// Given an array of Range values, return a tuple of (offset vector, sizes +/// vector, and strides vector) formed by separating out the individual elements +/// of each range. +std::tuple, SmallVector, + SmallVector> +getOffsetsSizesAndStrides(ArrayRef ranges); + /// Helper function to dispatch an OpFoldResult into `staticVec` if: /// a) it is an IntegerAttr /// In other cases, the OpFoldResult is dispached to the `dynamicVec`. diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1210,6 +1210,15 @@ build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); } +/// Build an ExtractSliceOp with mixed static and dynamic entries packed into a +/// Range vector. 
+void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, + ArrayRef ranges, + ArrayRef attrs) { + auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges); + build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); +} + /// Build an ExtractSliceOp with dynamic entries and custom result type. If the /// type passed is nullptr, it is inferred. void ExtractSliceOp::build(OpBuilder &b, OperationState &result, @@ -1597,6 +1606,15 @@ result.addAttributes(attrs); } +/// Build an InsertSliceOp with mixed static and dynamic entries packed into a +/// Range vector. +void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, + Value dest, ArrayRef ranges, + ArrayRef attrs) { + auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges); + build(b, result, source, dest, offsets, sizes, strides, attrs); +} + // Build a InsertSliceOp with dynamic entries. void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, Value dest, ValueRange offsets, ValueRange sizes, @@ -2359,6 +2377,16 @@ result.addAttributes(attrs); } +/// Build an ParallelInsertSliceOp with mixed static and dynamic entries packed +/// into a Range vector. +void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, + Value source, Value dest, + ArrayRef ranges, + ArrayRef attrs) { + auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges); + build(b, result, source, dest, offsets, sizes, strides, attrs); +} + // Build a ParallelInsertSliceOp with dynamic entries. 
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, Value dest, ValueRange offsets, diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTensorTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp + ExtractSliceFromReshape.cpp SplitPadding.cpp SwapExtractSliceWithProducer.cpp diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp @@ -0,0 +1,179 @@ +//===- ExtractSliceFromReshape.cpp - Slice reshape rewrites-------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements rewrites that replace slices of reshape results with +// aggregated slices of the reshape source. 
+// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace mlir::tensor; + +/// Get the dimension size of a value of RankedTensor type at the +OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, Value rankedTensor, + int64_t dimIdx) { + RankedTensorType tensorType = rankedTensor.getType().cast(); + if (!tensorType.isDynamicDim(dimIdx)) { + return b.getIndexAttr(tensorType.getDimSize(dimIdx)); + } + Value idxValue = b.create(loc, dimIdx); + return b.createOrFold(loc, rankedTensor, idxValue); +} + +/// Get all the dimension sizes of a value of RankedTensor type. +static SmallVector getShapeDimSizes(OpBuilder &b, Location loc, + Value rankedTensor) { + SmallVector dimSizes; + RankedTensorType tensorType = rankedTensor.getType().cast(); + for (unsigned i = 0; i < tensorType.getRank(); i++) + dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, i)); + return dimSizes; +} + +/// A tuple that represents (dimension number, dimension value). +using DimAndIndex = std::tuple; + +/// Transform `dimAndIndex` from the output index space of a (non-rank-reducing) +/// slice described by `sliceParams` into the input index space. 
+static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc, + ArrayRef sliceParams, + const DimAndIndex &dimAndIndex) { + AffineExpr d0, s0, s1; + bindDims(b.getContext(), d0); + bindSymbols(b.getContext(), s0, s1); + auto [dim, indexValue] = dimAndIndex; + assert(dim < sliceParams.size() && "slice should be non rank-reducing"); + return std::make_pair( + dim, + makeComposedAffineApply( + b, loc, s0 + d0 * s1, + {indexValue, + getValueOrCreateConstantIndexOp(b, loc, sliceParams[dim].offset), + getValueOrCreateConstantIndexOp(b, loc, sliceParams[dim].stride)})); +} + +/// Transform `dimAndIndex` from the result tensor index space of a +/// CollapseShapeOp to the source tensor index space. +static ValueRange invertCollapseShapeIndexing( + OpBuilder &b, Location loc, ArrayRef reassociation, + ArrayRef reshapeSourceShape, const DimAndIndex &dimAndIndex) { + const auto &[dim, indexValue] = dimAndIndex; + SmallVector basis; + for (int64_t i : reassociation[dim]) + basis.push_back(reshapeSourceShape[i]); + auto delinearized = + b.create(loc, indexValue, basis); + return delinearized->getResults(); +} + +FailureOr +tensor::ExtractSliceFromCollapseHelper::create( + OpBuilder &b, tensor::CollapseShapeOp collapseOp, + tensor::ExtractSliceOp extractOp) { + if (extractOp.getSource().getDefiningOp() != + collapseOp) + return failure(); + SmallVector ranges; + ranges.reserve(extractOp.getSourceType().getRank()); + for (const auto &[o, s, st] : + llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(), + extractOp.getMixedStrides())) { + ranges.push_back({o, s, st}); + } + return ExtractSliceFromCollapseHelper::create(b, collapseOp, ranges); +} + +FailureOr +tensor::ExtractSliceFromCollapseHelper::create(OpBuilder &b, + tensor::CollapseShapeOp op, + ArrayRef sliceParams) { + + // Materialize the output shape of the collapse_shape operation. This will + // create IR describing the output shape in terms of the input shape. 
+ ReifiedRankedShapedTypeDims reifiedShapes; + ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = + dyn_cast(op.getOperation()); + if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes))) + return failure(); + SmallVector collapseShapeOutputShape = + getAsOpFoldResult(reifiedShapes[0]); + SmallVector reassociationIndices = + op.getReassociationIndices(); + + // Determine which of the CollapseShapeOp's result dimensions are sliced + // and/or linearized. + llvm::SmallBitVector linearizedDimensions = + getLinearizedDimensions(reassociationIndices); + llvm::SmallBitVector slicedDimensions = + getSlicedDimensions(collapseShapeOutputShape, sliceParams); + + auto collapseShapeInputShape = getShapeDimSizes(b, op.getLoc(), op.getSrc()); + + SmallVector srcShape = + getShapeDimSizes(b, op->getLoc(), op.getSrc()); + + SmallVector tileSizes; + for (unsigned i = 0; i < sliceParams.size(); i++) { + if (slicedDimensions[i] && linearizedDimensions[i]) + tileSizes.push_back( + getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size)); + } + + return ExtractSliceFromCollapseHelper( + op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams, + linearizedDimensions, slicedDimensions, tileSizes); +} + +std::pair> +tensor::ExtractSliceFromCollapseHelper::emitLoopNestBody( + OpBuilder &builder, Location loc, ValueRange tileInductionVars) { + // Create the helper class for forming the slice parameters. + const SmallVector reassociationIndices = + collapseShapeOp.getReassociationIndices(); + SliceFromCollapseHelper helper(reassociationIndices, collapseShapeInputShape, + collapseShapeOutputShape, sliceParams); + + // Get the indices of the tiled dims (linearized by the collapse_shape + // and sliced by the extract_slice) invert the index spaces + // transformations. 
+ SmallVector multiIndices; + unsigned loopIdx = 0; + for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) { + if (linearizedDimensions[i] && slicedDimensions[i]) { + DimAndIndex tb = + invertSliceIndexing(builder, loc, sliceParams, + std::make_tuple(i, tileInductionVars[loopIdx++])); + multiIndices.push_back(invertCollapseShapeIndexing( + builder, loc, reassociationIndices, collapseShapeInputShape, tb)); + } + } + + auto extractParams = helper.getExtractSliceParams(multiIndices); + + Value subTileResult = builder.create( + loc, collapseShapeOp.getSrc(), extractParams); + + SmallVector insertParams = + helper.getInsertSliceParams(tileInductionVars); + + // Collapse the dimensions of the source slice back down. + Value collapsedResult = builder.create( + loc, subTileResult, reassociationIndices); + return std::make_pair(collapsedResult, insertParams); +} diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -270,3 +270,88 @@ return !memrefType.getLayout().isIdentity(); return false; } + +llvm::SmallBitVector +mlir::getSlicedDimensions(ArrayRef sliceInputShape, + ArrayRef sliceParams) { + assert(sliceParams.size() == sliceInputShape.size() && + "only supports non rank-reducing case"); + llvm::SmallBitVector mask(sliceInputShape.size()); + unsigned idx = 0; + for (const auto &[offset, size, stride] : sliceParams) { + Optional offsetConst = getConstantIntValue(offset); + Optional strideConst = getConstantIntValue(stride); + mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) || + (!strideConst || *strideConst != 1) || + (!offsetConst || *offsetConst != 0); + idx++; + } + return mask; +} + +llvm::SmallBitVector mlir::getLinearizedDimensions( + ArrayRef reassociationIndices) { + llvm::SmallBitVector result(reassociationIndices.size()); + for (const auto &it : llvm::enumerate(reassociationIndices)) 
+ result[it.index()] = it.value().size() > 1; + return result; +} + +SmallVector SliceFromCollapseHelper::getExtractSliceParams( + ArrayRef multiIndices) { + assert(!multiIndices.empty() && !multiIndices[0].empty() && + "multiIndices should not be empty"); + unsigned loopIdx = 0; + MLIRContext *ctx = multiIndices[0][0].getContext(); + auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1); + auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0); + SmallVector offsetsSizesAndStrides; + offsetsSizesAndStrides.reserve(collapseShapeInputShape.size()); + for (const auto &it : llvm::enumerate(reassociationIndices)) { + // Case 1: Linearized dimensions that have also been sliced. These + // are size of 1 because we are iterating over these dimensions. The + // offsets are exactly the de-linearized multi-indices. + if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) { + llvm::append_range( + offsetsSizesAndStrides, + llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range { + return Range{getAsOpFoldResult(v), oneAttr, oneAttr}; + })); + continue; + } + + // Case 2: One or possibly multiple combined input dimensions, but we + // have proven that these are not sliced. In this case we just take + // the full extent of each dimension in the reassociation list. + if (linearizedDimensions[it.index()]) { + llvm::append_range( + offsetsSizesAndStrides, + llvm::map_range(it.value(), [&](int64_t idx) -> Range { + return {zeroAttr, collapseShapeInputShape[idx], oneAttr}; + })); + continue; + } + + // Case 3: A single index, but it may be sliced. 
+ offsetsSizesAndStrides.push_back(sliceParams[it.index()]); + } + return offsetsSizesAndStrides; +} + +SmallVector +SliceFromCollapseHelper::getInsertSliceParams(ValueRange tileIndices) { + MLIRContext *ctx = tileIndices[0].getContext(); + auto one = IntegerAttr::get(IndexType::get(ctx), 1); + auto zero = IntegerAttr::get(IndexType::get(ctx), 0); + SmallVector insertParams; + insertParams.reserve(linearizedDimensions.size()); + unsigned loopIdx = 0; + for (unsigned i = 0; i < linearizedDimensions.size(); i++) { + if (linearizedDimensions[i] && slicedDimensions[i]) { + insertParams.push_back(Range{tileIndices[loopIdx++], one, one}); + continue; + } + insertParams.push_back(Range{zero, sliceParams[i].size, one}); + } + return insertParams; +} diff --git a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp --- a/mlir/lib/Dialect/Utils/StaticValueUtils.cpp +++ b/mlir/lib/Dialect/Utils/StaticValueUtils.cpp @@ -13,6 +13,21 @@ namespace mlir { +std::tuple, SmallVector, + SmallVector> +getOffsetsSizesAndStrides(ArrayRef ranges) { + SmallVector offsets, sizes, strides; + offsets.reserve(ranges.size()); + sizes.reserve(ranges.size()); + strides.reserve(ranges.size()); + for (const auto &[offset, size, stride] : ranges) { + offsets.push_back(offset); + sizes.push_back(size); + strides.push_back(stride); + } + return std::make_tuple(offsets, sizes, strides); +} + /// Helper function to dispatch an OpFoldResult into `staticVec` if: /// a) it is an IntegerAttr /// In other cases, the OpFoldResult is dispached to the `dynamicVec`. 
diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir @@ -0,0 +1,164 @@ +// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-rewrite-extract-slice-from-collapse-shape %s | FileCheck %s +// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-rewrite-extract-slice-from-collapse-shape use-foreach" %s | FileCheck %s --check-prefix=FOREACH + +func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf32> { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32> + %slice = tensor.extract_slice %collapsed [0, 0] [20, 11] [1, 1] : tensor<105x11xf32> to tensor<20x11xf32> + return %slice : tensor<20x11xf32> +} + +// CHECK: func.func @extract_slice_static(%[[arg0:.+]]: +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index +// CHECK-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] : +// CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) +// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]] +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: return %[[tile]] + +// 
FOREACH: func.func @extract_slice_static(%[[arg0:.+]]: +// FOREACH-DAG: %[[c20:.+]] = arith.constant 20 : index +// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index +// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index +// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index +// FOREACH-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] : +// FOREACH: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (%[[c20]]) shared_outs(%[[dest:.+]] = %[[init]]) +// FOREACH: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]] +// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : +// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// FOREACH: perform_concurrently +// FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[iv]], 0] [1, 11] [1, 1] : +// FOREACH: return %[[tile]] + +// ----- + + +func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor<10x5xf32> { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32> + %slice = tensor.extract_slice %collapsed [13, 0] [10, 5] [2, 2] : tensor<105x11xf32> to tensor<10x5xf32> + return %slice : tensor<10x5xf32> +} + +// CHECK: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2 + 13)> +// CHECK: func.func @extract_slice_static_strided(%[[arg0:.+]]: +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index +// CHECK: %[[init:.+]] = linalg.init_tensor [10, 5] : +// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) +// CHECK: %[[inputIv:.+]] = affine.apply 
#[[$map0]](%[[iv]]) +// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[c5]], %[[c7]] +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: return %[[tile]] + + +// ----- + + +func.func @extract_slice_dynamic(%input: tensor<3x?x?x11xf32>, %offt: index, %size: index) -> tensor { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x?x?x11xf32> into tensor + %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 5] [2, 2] : tensor to tensor + return %slice : tensor +} + +// CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)> +// CHECK: func.func @extract_slice_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[lb:.+]]: index, %[[sz:.+]]: index) +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz]], 5] : tensor +// CHECK-DAG: %[[d1:.+]] = tensor.dim %arg0, %[[c1]] : tensor<3x?x?x11xf32> +// CHECK-DAG: %[[d2:.+]] = tensor.dim %arg0, %[[c2]] : tensor<3x?x?x11xf32> +// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[sz]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) +// CHECK: %[[inputIv:.+]] = affine.apply #[[map0]](%[[iv]])[%[[lb]]] +// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[d1]], %[[d2]]) : +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 
1, 2], [3]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: return %[[tile]] : + +// ----- + + +func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0: index, %size0: index, %offt1: index, %size1: index) -> tensor { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x?xf32> into tensor + %slice = tensor.extract_slice %collapsed [%offt0, %offt1] [%size0, %size1] [1, 1] : tensor to tensor + return %slice : tensor +} + +// CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index) +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index +// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor +// CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : +// CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : +// CHECK-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] : +// CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[c0]] to %[[sz1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]]) +// CHECK: %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[c0]] to %[[sz2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]]) +// CHECK: %[[inputIv1:.+]] = affine.apply #[[map0:.+]](%[[iv1]])[%[[lb1]]] +// CHECK: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[inputIv1]] into (%[[c3]], %[[d1]], %[[d2]]) : +// CHECK: %[[inputIv2:.+]] = affine.apply #[[map0:.+]](%[[iv2]])[%[[lb2]]] +// CHECK: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[inputIv2]] into (%[[c11]], 
%[[d4]]) : +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: scf.yield %[[tile2]] : +// CHECK: return %[[tile1]] : + +// FOREACH: #[[map1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// FOREACH: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index) +// FOREACH-DAG: %[[c1:.+]] = arith.constant 1 : index +// FOREACH-DAG: %[[c2:.+]] = arith.constant 2 : index +// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index +// FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index +// FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index +// FOREACH: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor +// FOREACH-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : +// FOREACH-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : +// FOREACH-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] : +// FOREACH: %[[tile1:.+]] = scf.foreach_thread (%[[tid1:.+]], %[[tid2:.+]]) in (%[[sz1]], %[[sz2]]) shared_outs(%[[dest:.+]] = %[[init]]) +// FOREACH-DAG: %[[iv1:.+]] = affine.apply #[[map1]](%[[tid1]])[%[[lb1]]] +// FOREACH: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[iv1]] into (%[[c3]], %[[d1]], %[[d2]]) : +// FOREACH-DAG: %[[iv2:.+]] = affine.apply #[[map1]](%[[tid2]])[%[[lb2]]] +// FOREACH: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[iv2]] into (%[[c11]], %[[d4]]) : +// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : +// FOREACH: 
%[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} : +// FOREACH: perform_concurrently +//FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[dest]][%[[tid1]], %[[tid2]]] [1, 1] [1, 1] : + +// ----- + +// Verifies that a linearized dimension that is not sliced does not generate a loop. Note that this +// only works for static shapes. + +// CHECK: @extract_slice_non_sliced_linearized_dim(%[[arg0:.+]]: tensor<{{.*}}>, +func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32>, %offt: index, %size: index) -> tensor { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x2xf32> into tensor + %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 22] [1, 1] : tensor to tensor + // CHECK: scf.for + // CHECK-NOT: scf.for + // CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index + // CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1] + return %slice : tensor +} diff --git a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt --- a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt @@ -6,6 +6,7 @@ LINK_LIBS PUBLIC MLIRArithmeticDialect + MLIRLinalgDialect MLIRPass MLIRSCFDialect MLIRTensorDialect diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -11,8 +11,10 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/TransformUtils.h" #include 
"mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -28,7 +30,8 @@ TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } StringRef getArgument() const final { @@ -49,6 +52,19 @@ *this, "test-fold-constant-extract-slice", llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"), llvm::cl::init(false)}; + + Option testRewriteExtractSliceWithTiledCollapseShape{ + *this, "test-rewrite-extract-slice-from-collapse-shape", + llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape " + "with loop nest"), + llvm::cl::init(false)}; + + Option useForeach{ + *this, "use-foreach", + llvm::cl::desc( + "Use the scf.foreach_thread operation when generating loop nests for " + "the extract_slice of collapse_shape pattern"), + llvm::cl::init(false)}; }; } // namespace @@ -74,12 +90,142 @@ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); } +namespace { +/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`. +/// The `tensor.extract_slice` is replaced by a loop or gather operation that +/// stitches together the desired tile from slices of the source of the collapse +/// shape op. +struct RewriteExtractSliceFromCollapseShapeBase + : public OpRewritePattern { + RewriteExtractSliceFromCollapseShapeBase(MLIRContext *context) + : mlir::OpRewritePattern(context) {} + + /// Emit a loop or gather operation that uses `helper` to take each point in + /// the parallel iteration space bounds, extract a slice from the source + /// tensor and insert it into `dest`. For examples, see below for `scf.for` + /// and `scf.foreach`. 
+ virtual LogicalResult + emitReplacement(tensor::ExtractSliceOp op, Value dest, + tensor::ExtractSliceFromCollapseHelper &helper, + PatternRewriter &rewriter) const = 0; + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp op, + PatternRewriter &rewriter) const override { + auto collapseOp = op.getSource().getDefiningOp(); + if (!collapseOp) + return rewriter.notifyMatchFailure( + op, "producer is not a tensor.collapse_shape op"); + + // Materialize the output shape values of the slice operation. + ReifiedRankedShapedTypeDims reifiedShapes; + if (failed(op.reifyResultShapes(rewriter, reifiedShapes))) + return rewriter.notifyMatchFailure(op, "failed to reify result shapes"); + + // Create the destination tensor using the above values. + Type elementType = op.getSourceType().getElementType(); + SmallVector outputShape = getAsOpFoldResult(reifiedShapes[0]); + Value dest = rewriter.create( + op->getLoc(), outputShape, elementType); + + // Calculate the parameters for the tile loop nest. 
+ FailureOr params = + tensor::ExtractSliceFromCollapseHelper::create(rewriter, collapseOp, + op); + if (failed(params)) + return rewriter.notifyMatchFailure( + op, "could not calculate tiling parameters"); + return emitReplacement(op, dest, *params, rewriter); + } +}; + +struct RewriteExtractSliceFromCollapseShapeUsingScfFor + : public RewriteExtractSliceFromCollapseShapeBase { + RewriteExtractSliceFromCollapseShapeUsingScfFor(MLIRContext *context) + : RewriteExtractSliceFromCollapseShapeBase(context) {} + LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest, + tensor::ExtractSliceFromCollapseHelper &helper, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + const unsigned numTiledDims = helper.getIterationSpaceSizes().size(); + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + SmallVector lbs(numTiledDims, zero); + SmallVector steps(numTiledDims, one); + scf::LoopNest nest = scf::buildLoopNest( + rewriter, loc, lbs, helper.getIterationSpaceSizes(), steps, dest, + [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs, + ValueRange iterArgs) -> scf::ValueVector { + auto [tile, insertParams] = + helper.emitLoopNestBody(nestedBuilder, loc, outputIvs); + + // Insert the slice into the destination. 
+ Value result = nestedBuilder.create( + loc, tile, iterArgs[0], insertParams); + return {result}; + }); + rewriter.replaceOp(op, nest.getResults()[0]); + return success(); + } +}; + +struct RewriteExtractSliceFromCollapseShapeUsingScfForeach + : public RewriteExtractSliceFromCollapseShapeBase { + RewriteExtractSliceFromCollapseShapeUsingScfForeach(MLIRContext *context) + : RewriteExtractSliceFromCollapseShapeBase(context) {} + LogicalResult emitReplacement(tensor::ExtractSliceOp op, Value dest, + tensor::ExtractSliceFromCollapseHelper &helper, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto foreachOp = rewriter.create( + loc, /*outputs=*/dest, /*numThreads=*/helper.getIterationSpaceSizes(), + /*threadDimMapping=*/ArrayRef{}, + [&](OpBuilder &nestedBuilder, Location loc, ValueRange regionArgs) { + unsigned numThreadIdRegionArgs = + helper.getIterationSpaceSizes().size(); + unsigned numOutputRegionArgs = + regionArgs.size() - numThreadIdRegionArgs; + ValueRange outputIvs = regionArgs.take_front(numThreadIdRegionArgs); + ValueRange outputArgs = regionArgs.take_back(numOutputRegionArgs); + assert(outputArgs.size() == 1 && + "there should only be one output region argument"); + auto [tile, insertParams] = + helper.emitLoopNestBody(nestedBuilder, loc, outputIvs); + // Insert the slice into the destination. 
+ auto term = nestedBuilder.create(loc); + nestedBuilder.setInsertionPointToStart(term.getBody()); + nestedBuilder.create( + loc, tile, outputArgs[0], insertParams); + }); + rewriter.replaceOp(op, foreachOp->getResult(0)); + return success(); + } +}; +} // namespace + +static LogicalResult +applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp, + bool useForeach) { + RewritePatternSet patterns(rootOp->getContext()); + if (useForeach) + patterns.add( + rootOp->getContext()); + else + patterns.add( + rootOp->getContext()); + return applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); +} + void TestTensorTransforms::runOnOperation() { Operation *rootOp = getOperation(); if (testSplitPaddingPatterns) applySplitPaddingPatterns(rootOp); if (testFoldConstantExtractSlice) applyFoldConstantExtractSlicePatterns(rootOp); + if (testRewriteExtractSliceWithTiledCollapseShape) { + if (failed( + applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach))) + return signalPassFailure(); + } } namespace mlir { diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -5061,11 +5061,13 @@ "include/mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h", "include/mlir/Dialect/Tensor/Transforms/Passes.h", "include/mlir/Dialect/Tensor/Transforms/Transforms.h", + "include/mlir/Dialect/Tensor/Transforms/TransformUtils.h" ], includes = ["include"], deps = [ ":AffineDialect", ":ArithmeticDialect", + ":ArithmeticUtils", ":BufferizationDialect", ":BufferizationTransforms", ":DialectUtils", diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -620,6 +620,7 @@ includes = ["lib/Dialect/Test"], deps = [ 
"//mlir:ArithmeticDialect", + "//mlir:LinalgDialect", "//mlir:Pass", "//mlir:SCFDialect", "//mlir:TensorDialect",