diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -283,7 +283,11 @@ // Build an ExtractSliceOp with dynamic entries and inferred result type. OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, - CArg<"ArrayRef", "{}">:$attrs)> + CArg<"ArrayRef", "{}">:$attrs)>, + // Build an ExtractSliceOp with mixed static and dynamic entries packed in + // a Range vector. + OpBuilder<(ins "Value":$source, "ArrayRef":$ranges, + CArg<"ArrayRef", "{}">:$attrs)>, ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ @@ -601,6 +605,11 @@ // Build a InsertSliceOp with dynamic entries. OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, + CArg<"ArrayRef", "{}">:$attrs)>, + // Build an InsertSliceOp with mixed static and dynamic entries packed in + // a Range vector. + OpBuilder<(ins "Value":$source, "Value":$dest, + "ArrayRef":$ranges, CArg<"ArrayRef", "{}">:$attrs)> ]; @@ -1199,7 +1208,11 @@ "ArrayRef":$offsets, "ArrayRef":$sizes, "ArrayRef":$strides, CArg<"ArrayRef", "{}">:$attrs)>, - + // Build a ParallelInsertSliceOp with mixed static and dynamic entries + // packed into a Range vector. + OpBuilder<(ins "Value":$source, "Value":$dest, + "ArrayRef":$ranges, + CArg<"ArrayRef", "{}">:$attrs)>, // Build a ParallelInsertSliceOp with dynamic entries. 
OpBuilder<(ins "Value":$source, "Value":$dest, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -29,6 +29,110 @@ FailureOr replaceExtractSliceWithTiledProducer( OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp); +//===----------------------------------------------------------------------===// +// Extract slice from `tensor.collapse_shape` +//===----------------------------------------------------------------------===// + +/// This class assists with generating IR required to materialize an +/// arbitrary-sized slice from the result of a CollapseShapeOp. In order to +/// accomplish this, a loop-nest or similar operation must be created by the +/// caller. The class provides two methods: the `create` method will +/// emit necessary IR that should appear before the loop and populate the +/// internal state of the class. The caller should then create a new destination +/// tensor that is the same size as the desired slice and a loop-nest that +/// iterates over the multi-dimensional iteration space defined by +/// `[0, ub[0]) x [0, ub[1]) x ... x [0, ub[N-1])` where `ub` is the upper bound +/// and can be found by calling `getIterationSpaceSizes`. Inside the body of the +/// loop nest, the caller should call +/// `emitExtractSliceFromCollapseShapeLoopNestBody` and provide the induction +/// variables. This method returns a sub-tile of the desired slice result and a +/// set of ranges that describe where this tile should be inserted into the +/// result by the caller. For a complete example of usage, see the +/// TestTensorTransforms test pass. The below example illustrates the pattern: +// clang-format off +/// ``` +/// %0 = linalg.generic ... 
-> tensor<3x7x11x10xf32> +/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] : ... to tensor<231x10xf32> +/// %2 = tensor.extract_slice %1 [13, 0] [10, 10] [2, 1] : ... to tensor<10x10xf32> +/// ``` +/// We can construct %2 by generating the following IR: +/// ``` +/// %dest = linalg.init_tensor() : tensor<10x10xf32> +/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %dest) -> tensor<10x10xf32> { +/// // Step 1: Map this output idx (%iv) to a multi-index for the input (%3): +/// %linear_index = affine.apply affine_map<(d0)[]->(d0*2 + 13)>(%iv) +/// %3:3 = arith.delinearize_index %linear_index into (3, 7, 11) +/// // Step 2: Extract the slice from the input +/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] : +/// tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32> +/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] : +/// tensor<1x1x1x10xf32> into tensor<1x10xf32> +/// // Step 3: Insert the slice into the destination +/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] : +/// tensor<1x10xf32> into tensor<10x10xf32> +/// scf.yield %6 : tensor<10x10xf32> +/// } +/// ``` +// clang-format on +class ExtractSliceFromCollapseShapeBuilder { +public: + /// Given a CollapseShapeOp and a set of ranges describing the desired slice + /// of its result, emits IR to materialize the shapes of the input and output + /// tensors, and returns an instance of the initialized class. Returns failure + /// if the slice is rank-reducing. 
+ static FailureOr + create(OpBuilder &b, tensor::CollapseShapeOp collapseOp, + tensor::ExtractSliceOp extractOp); + + ExtractSliceFromCollapseShapeBuilder( + tensor::CollapseShapeOp collapseShapeOp, + ArrayRef collapseShapeInputShape, + ArrayRef collapseShapeOutputShape, + ArrayRef extractSliceParams, + const llvm::SmallBitVector &linearizedDimensions, + const llvm::SmallBitVector &slicedDimensions, ArrayRef tiledSizes) + : collapseShapeOp(collapseShapeOp), + collapseShapeInputShape(collapseShapeInputShape), + collapseShapeOutputShape(collapseShapeOutputShape), + sliceParams(extractSliceParams), + linearizedDimensions(linearizedDimensions), + slicedDimensions(slicedDimensions), tiledSizes(tiledSizes) {} + + /// Return the upper bounds of the iteration space (with 0 offset and stride + /// 1) required to create the desired slice. Note that this is not the same + /// as the `sizes` parameters of the ExtractSliceOp because not all dimensions + /// of the slice are required to be tiled to form the result. + const SmallVector &getIterationSpaceSizes() { return tiledSizes; } + + /// Generates the IR inside of the caller's loop nest for 1) inverting the + /// index mappings of the ExtractSliceOp->CollapseShapeOp chain and 2) + /// extracting the CollapseShapeOp source tensor tile for this specified + /// iteration space point `tileInductionVars` and 3) calculating where to + /// insert the extracted tile. The returned pair consists of the results of + /// (2) and (3) and should be used by the caller to insert into the + /// destination tensor. 
+ std::pair> + emitExtractSliceFromCollapseShapeLoopNestBody(OpBuilder &builder, + Location loc, + ValueRange tileInductionVars); + +private: + tensor::CollapseShapeOp collapseShapeOp; + SmallVector collapseShapeInputShape; + SmallVector collapseShapeOutputShape; + SmallVector sliceParams; + llvm::SmallBitVector linearizedDimensions; + llvm::SmallBitVector slicedDimensions; + SmallVector tiledSizes; +}; + } // namespace tensor } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -16,6 +16,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" @@ -373,6 +374,71 @@ } }; +/// The input parameters `offsets`, `sizes`, `strides` specify a rectangular +/// non rank-reducing slice of the collapse_shape output. Try to find which +/// dimensions have been sliced and which dimensions are not sliced (offset = 0, +/// size = dim, stride = 1). Note that this is conservative as it cannot detect +/// if a dynamic size corresponds to the full tensor dimension or not. +llvm::SmallBitVector getSlicedDimensions(ArrayRef sliceInputShape, + ArrayRef sliceParams); + +/// Determine which dimensions are linearized by a `tensor.collapse_shape` op by +/// inspecting its reassociation indices. +llvm::SmallBitVector +getLinearizedDimensions(ArrayRef reassociationIndices); + +/// Given the parameters for both operations in a `CollapseShape->ExtractSlice` +/// chain and reified source and result shapes of the CollapseShapeOp, this +/// class provides two functions that assist with directly forming the result +/// of the extract slice by "tiling the CollapseShapeOp by 1". 
+class ExtractShapeExtractSliceBuilder { +public: + ExtractShapeExtractSliceBuilder( + ArrayRef reassociationIndices, + ArrayRef collapseShapeInputShape, + ArrayRef collapseShapeOutputShape, + ArrayRef extractSliceParams) + : reassociationIndices(reassociationIndices), + collapseShapeInputShape(collapseShapeInputShape), + collapseShapeOutputShape(collapseShapeOutputShape), + sliceParams(extractSliceParams), + linearizedDimensions(getLinearizedDimensions(reassociationIndices)), + slicedDimensions(getSlicedDimensions(collapseShapeOutputShape, + extractSliceParams)) {} + + /// This function takes multi-indices created by inverting the + /// `CollapseShape->ExtractSlice` index-space transformations and maps the + /// multi-indices to ExtractSlice parameters in the index space of the + /// CollapseShape's source tensor. This function's signature can be described + /// by `(D_0, D_1, ..., D_{n-1}) -> (offsets, sizes, strides)` where + /// `n` is equal to the number of output dimensions that correspond to + /// multiple dimensions of the CollapseShape's source and are sliced by the + /// ExtractSliceOp. We say `n` is the number of "tile dimensions". Each `D_i` + /// is a tuple that must represent a valid multi-index for the `i-th` tile + /// dimension. Concretely, if the input shape is `[s0, s1, s2, s3]` and the + /// output shape is `[s0 * s1, s2 * s3]`, and both dimensions are sliced, then + /// `D_0 = (d0, d1)`, `D_1 = (d2, d3)`, where `D_0` gives the coordinates for + /// the two collapsed input dimensions of size `s0` and `s1`. The multi-index + /// `D_0` must then satisfy `0 <= d0*s1+d1 < s0*s1`. The function only needs + /// to accept the multi-indices for the tiled dimensions because the Ranges + /// for dimensions that are either not sliced or not linearized can be + /// immediately inferred from the parameters of `CollapseShape->ExtractSlice` + /// given in the constructor. 
+ SmallVector getExtractSliceParams(ArrayRef multiIndices); + + /// This function takes indices in the index space of the "tiled dimensions" + /// described above and returns a set of Range variables that describe how the + /// slice should be inserted into the destination. + SmallVector getInsertSliceParams(ValueRange tileIndices); + +private: + SmallVector reassociationIndices; + SmallVector collapseShapeInputShape; + SmallVector collapseShapeOutputShape; + SmallVector sliceParams; + llvm::SmallBitVector linearizedDimensions; + llvm::SmallBitVector slicedDimensions; +}; } // namespace mlir #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -30,6 +30,13 @@ OpFoldResult stride; }; +/// Given an array of Range values, return a tuple of (offset vector, sizes +/// vector, and strides vector) formed by separating out the individual elements +/// of each range. +std::tuple, SmallVector, + SmallVector> +getOffsetsSizesAndStrides(ArrayRef ranges); + /// Return a vector of OpFoldResults given the special value /// that indicates whether of the value is dynamic or not. SmallVector getMixedValues(ArrayAttr staticValues, diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -1109,6 +1109,15 @@ build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); } +/// Build an ExtractSliceOp with mixed static and dynamic entries packed into a +/// Range vector. 
+void ExtractSliceOp::build(OpBuilder &b, OperationState &result, Value source, + ArrayRef ranges, + ArrayRef attrs) { + auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges); + build(b, result, RankedTensorType(), source, offsets, sizes, strides, attrs); +} + /// Build an ExtractSliceOp with dynamic entries and custom result type. If the /// type passed is nullptr, it is inferred. void ExtractSliceOp::build(OpBuilder &b, OperationState &result, @@ -1506,6 +1515,15 @@ result.addAttributes(attrs); } +/// Build an InsertSliceOp with mixed static and dynamic entries packed into a +/// Range vector. +void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, + Value dest, ArrayRef ranges, + ArrayRef attrs) { + auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges); + build(b, result, source, dest, offsets, sizes, strides, attrs); +} + // Build a InsertSliceOp with dynamic entries. void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, Value dest, ValueRange offsets, ValueRange sizes, @@ -2298,6 +2316,16 @@ result.addAttributes(attrs); } +/// Build a ParallelInsertSliceOp with mixed static and dynamic entries packed +/// into a Range vector. +void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, + Value source, Value dest, + ArrayRef ranges, + ArrayRef attrs) { + auto [offsets, sizes, strides] = getOffsetsSizesAndStrides(ranges); + build(b, result, source, dest, offsets, sizes, strides, attrs); +} + // Build a ParallelInsertSliceOp with dynamic entries. 
void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, Value dest, ValueRange offsets, diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTensorTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp + ExtractSliceFromReshape.cpp SplitPadding.cpp SwapExtractSliceWithProducer.cpp diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp @@ -0,0 +1,181 @@ +//===- ExtractSliceFromReshape.cpp - Slice reshape rewrites-------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements rewrites that replace slices of reshape results with +// aggregated slices of the reshape source. 
+// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace mlir::tensor; + +/// Get the size of dimension `dimIdx` of a value of RankedTensor type. +static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, + Value rankedTensor, int64_t dimIdx) { + RankedTensorType tensorType = rankedTensor.getType().cast(); + if (!tensorType.isDynamicDim(dimIdx)) { + return b.getIndexAttr(tensorType.getDimSize(dimIdx)); + } + Value idxValue = b.create(loc, dimIdx); + return b.createOrFold(loc, rankedTensor, idxValue); +} + +/// Get all the dimension sizes of a value of RankedTensor type. +static SmallVector getShapeDimSizes(OpBuilder &b, Location loc, + Value rankedTensor) { + SmallVector dimSizes; + RankedTensorType tensorType = rankedTensor.getType().cast(); + for (int64_t i = 0, e = tensorType.getRank(); i < e; ++i) + dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, i)); + return dimSizes; +} + +/// A tuple that represents (dimension number, index value). +using DimAndIndex = std::tuple; + +/// Transform `dimAndIndex` from the output index space of a (non-rank-reducing) +/// slice described by `sliceParams` into the input index space. 
+static DimAndIndex invertSliceIndexing(OpBuilder &b, Location loc, + ArrayRef sliceParams, + const DimAndIndex &dimAndIndex) { + AffineExpr d0, s0, s1; + bindDims(b.getContext(), d0); + bindSymbols(b.getContext(), s0, s1); + auto [dim, indexValue] = dimAndIndex; + assert(dim < sliceParams.size() && "slice should be non rank-reducing"); + return std::make_pair( + dim, + makeComposedAffineApply( + b, loc, s0 + d0 * s1, + {indexValue, + getValueOrCreateConstantIndexOp(b, loc, sliceParams[dim].offset), + getValueOrCreateConstantIndexOp(b, loc, sliceParams[dim].stride)})); +} + +/// Transform `dimAndIndex` from the result tensor index space of a +/// CollapseShapeOp to the source tensor index space. +static ValueRange invertCollapseShapeIndexing( + OpBuilder &b, Location loc, ArrayRef reassociation, + ArrayRef reshapeSourceShape, const DimAndIndex &dimAndIndex) { + const auto &[dim, indexValue] = dimAndIndex; + SmallVector basis; + for (int64_t i : reassociation[dim]) + basis.push_back(reshapeSourceShape[i]); + auto delinearized = + b.create(loc, indexValue, basis); + return delinearized->getResults(); +} + +FailureOr +tensor::ExtractSliceFromCollapseShapeBuilder::create( + OpBuilder &b, tensor::CollapseShapeOp collapseOp, + tensor::ExtractSliceOp extractOp) { + if (extractOp.getSource().getDefiningOp() != + collapseOp) + return failure(); + auto ranges = llvm::to_vector(llvm::map_range( + llvm::zip(extractOp.getMixedOffsets(), extractOp.getMixedSizes(), + extractOp.getMixedStrides()), + [&](const auto &it) -> Range { + const auto &[o, s, st] = it; + return Range{o, s, st}; + })); + return ExtractSliceFromCollapseShapeBuilder::create(b, collapseOp, ranges); +} + +FailureOr +tensor::ExtractSliceFromCollapseShapeBuilder::create( + OpBuilder &b, tensor::CollapseShapeOp op, ArrayRef sliceParams) { + + // Materialize the output shape of the collapse_shape operation. This will + // create IR describing the output shape in terms of the input shape. 
+ ReifiedRankedShapedTypeDims reifiedShapes; + ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = + cast(op.getOperation()); + if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes))) + return failure(); + SmallVector collapseShapeOutputShape = + getAsOpFoldResult(reifiedShapes[0]); + SmallVector reassociationIndices = + op.getReassociationIndices(); + + // Determine which of the CollapseShapeOp's result dimensions are sliced + // and/or linearized. + llvm::SmallBitVector linearizedDimensions = + getLinearizedDimensions(reassociationIndices); + llvm::SmallBitVector slicedDimensions = + getSlicedDimensions(collapseShapeOutputShape, sliceParams); + + // Materialize the input shape of the collapse_shape operation exactly once; + // this may emit IR for dynamic dimensions, so it must not be duplicated. + auto collapseShapeInputShape = + getShapeDimSizes(b, op.getLoc(), op.getSrc()); + + SmallVector tileSizes; + for (unsigned i = 0; i < sliceParams.size(); i++) { + if (slicedDimensions[i] && linearizedDimensions[i]) + tileSizes.push_back( + getValueOrCreateConstantIndexOp(b, op.getLoc(), sliceParams[i].size)); + } + + return ExtractSliceFromCollapseShapeBuilder( + op, collapseShapeInputShape, collapseShapeOutputShape, sliceParams, + linearizedDimensions, slicedDimensions, tileSizes); +} + +std::pair> +tensor::ExtractSliceFromCollapseShapeBuilder:: + emitExtractSliceFromCollapseShapeLoopNestBody( + OpBuilder &builder, Location loc, ValueRange tileInductionVars) { + // Create the helper class for forming the slice parameters. + const SmallVector reassociationIndices = + collapseShapeOp.getReassociationIndices(); + ExtractShapeExtractSliceBuilder helper(reassociationIndices, + collapseShapeInputShape, + collapseShapeOutputShape, sliceParams); + + // Get the indices of the tiled dims (linearized by the collapse_shape + // and sliced by the extract_slice) and invert the index space + // transformations. 
+ SmallVector multiIndices; + unsigned loopIdx = 0; + for (unsigned i = 0, e = linearizedDimensions.size(); i < e; i++) { + if (linearizedDimensions[i] && slicedDimensions[i]) { + DimAndIndex tb = + invertSliceIndexing(builder, loc, sliceParams, + std::make_tuple(i, tileInductionVars[loopIdx++])); + multiIndices.push_back(invertCollapseShapeIndexing( + builder, loc, reassociationIndices, collapseShapeInputShape, tb)); + } + } + + auto extractParams = helper.getExtractSliceParams(multiIndices); + + Value subTileResult = builder.create( + loc, collapseShapeOp.getSrc(), extractParams); + + SmallVector insertParams = + helper.getInsertSliceParams(tileInductionVars); + + // Collapse the dimensions of the source slice back down. + Value collapsedResult = builder.create( + loc, subTileResult, reassociationIndices); + return std::make_pair(collapsedResult, insertParams); +} diff --git a/mlir/lib/Dialect/Utils/CMakeLists.txt b/mlir/lib/Dialect/Utils/CMakeLists.txt --- a/mlir/lib/Dialect/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/Utils/CMakeLists.txt @@ -6,4 +6,5 @@ LINK_LIBS PUBLIC MLIRIR + MLIRViewLikeInterface ) diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -8,8 +8,11 @@ #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/ViewLikeInterface.h" #include @@ -270,3 +273,88 @@ return !memrefType.getLayout().isIdentity(); return false; } + +llvm::SmallBitVector +mlir::getSlicedDimensions(ArrayRef sliceInputShape, + ArrayRef sliceParams) { + assert(sliceParams.size() == sliceInputShape.size() && + "only supports non rank-reducing case"); + llvm::SmallBitVector mask(sliceInputShape.size()); + unsigned idx = 0; + for (const auto &[offset, size, stride] : 
sliceParams) { + Optional offsetConst = getConstantIntValue(offset); + Optional strideConst = getConstantIntValue(stride); + mask[idx] = !isEqualConstantIntOrValue(size, sliceInputShape[idx]) || + (!strideConst || *strideConst != 1) || + (!offsetConst || *offsetConst != 0); + idx++; + } + return mask; +} + +llvm::SmallBitVector mlir::getLinearizedDimensions( + ArrayRef reassociationIndices) { + llvm::SmallBitVector result(reassociationIndices.size()); + for (const auto &it : llvm::enumerate(reassociationIndices)) + result[it.index()] = it.value().size() > 1; + return result; +} + +SmallVector ExtractShapeExtractSliceBuilder::getExtractSliceParams( + ArrayRef multiIndices) { + assert(!multiIndices.empty() && !multiIndices[0].empty() && + "multiIndices should not be empty"); + unsigned loopIdx = 0; + MLIRContext *ctx = multiIndices[0][0].getContext(); + auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1); + auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0); + SmallVector offsetsSizesAndStrides; + offsetsSizesAndStrides.reserve(collapseShapeInputShape.size()); + for (const auto &it : llvm::enumerate(reassociationIndices)) { + // Case 1: Linearized dimensions that have also been sliced. These + // are size of 1 because we are iterating over these dimensions. The + // offsets are exactly the de-linearized multi-indices. + if (slicedDimensions[it.index()] && linearizedDimensions[it.index()]) { + llvm::append_range( + offsetsSizesAndStrides, + llvm::map_range(multiIndices[loopIdx++], [&](Value v) -> Range { + return Range{getAsOpFoldResult(v), oneAttr, oneAttr}; + })); + continue; + } + + // Case 2: One or possibly multiple combined input dimensions, but we + // have proven that these are not sliced. In this case we just take + // the full extent of each dimension in the reassociation list. 
+ if (linearizedDimensions[it.index()]) { + llvm::append_range( + offsetsSizesAndStrides, + llvm::map_range(it.value(), [&](int64_t idx) -> Range { + return {zeroAttr, collapseShapeInputShape[idx], oneAttr}; + })); + continue; + } + + // Case 3: A single index, but it may be sliced. + offsetsSizesAndStrides.push_back(sliceParams[it.index()]); + } + return offsetsSizesAndStrides; +} + +SmallVector +ExtractShapeExtractSliceBuilder::getInsertSliceParams(ValueRange tileIndices) { + MLIRContext *ctx = tileIndices[0].getContext(); + auto one = IntegerAttr::get(IndexType::get(ctx), 1); + auto zero = IntegerAttr::get(IndexType::get(ctx), 0); + SmallVector insertParams; + insertParams.reserve(linearizedDimensions.size()); + unsigned loopIdx = 0; + for (unsigned i = 0; i < linearizedDimensions.size(); i++) { + if (linearizedDimensions[i] && slicedDimensions[i]) { + insertParams.push_back(Range{tileIndices[loopIdx++], one, one}); + continue; + } + insertParams.push_back(Range{zero, sliceParams[i].size, one}); + } + return insertParams; +} diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -17,6 +17,21 @@ /// Include the definitions of the loop-like interfaces. 
#include "mlir/Interfaces/ViewLikeInterface.cpp.inc" +std::tuple, SmallVector, + SmallVector> +mlir::getOffsetsSizesAndStrides(ArrayRef ranges) { + SmallVector offsets, sizes, strides; + offsets.reserve(ranges.size()); + sizes.reserve(ranges.size()); + strides.reserve(ranges.size()); + for (const auto &[offset, size, stride] : ranges) { + offsets.push_back(offset); + sizes.push_back(size); + strides.push_back(stride); + } + return std::make_tuple(offsets, sizes, strides); +} + LogicalResult mlir::verifyListOfOperandsOrIntegers( Operation *op, StringRef name, unsigned numElements, ArrayAttr attr, ValueRange values, llvm::function_ref isDynamic) { diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir @@ -0,0 +1,164 @@ +// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-rewrite-extract-slice-from-collapse-shape %s | FileCheck %s +// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-rewrite-extract-slice-from-collapse-shape use-foreach" %s | FileCheck %s --check-prefix=FOREACH + +func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf32> { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32> + %slice = tensor.extract_slice %collapsed [0, 0] [20, 11] [1, 1] : tensor<105x11xf32> to tensor<20x11xf32> + return %slice : tensor<20x11xf32> +} + +// CHECK: func.func @extract_slice_static(%[[arg0:.+]]: +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index +// CHECK-DAG: %[[init:.+]] = linalg.init_tensor [20, 
11] : +// CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) +// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]] +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: return %[[tile]] + +// FOREACH: func.func @extract_slice_static(%[[arg0:.+]]: +// FOREACH-DAG: %[[c20:.+]] = arith.constant 20 : index +// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index +// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index +// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index +// FOREACH-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] : +// FOREACH-DAG: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (%[[c20]]) +// FOREACH: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[iv]] into (%[[c3]], %[[c5]], %[[c7]] +// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : +// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// FOREACH: perform_concurrently +// FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[init]][%[[iv]], 0] [1, 11] [1, 1] : +// FOREACH: return %[[tile]] + +// ----- + + +func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor<10x5xf32> { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32> + %slice = tensor.extract_slice %collapsed [13, 0] [10, 5] [2, 2] : tensor<105x11xf32> to tensor<10x5xf32> + return %slice : tensor<10x5xf32> +} + +// CHECK: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2 + 
13)> +// CHECK: func.func @extract_slice_static_strided(%[[arg0:.+]]: +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index +// CHECK: %[[init:.+]] = linalg.init_tensor [10, 5] : +// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) +// CHECK: %[[inputIv:.+]] = affine.apply #[[$map0]](%[[iv]]) +// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[c5]], %[[c7]] +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: return %[[tile]] + + +// ----- + + +func.func @extract_slice_dynamic(%input: tensor<3x?x?x11xf32>, %offt: index, %size: index) -> tensor { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x?x?x11xf32> into tensor + %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 5] [2, 2] : tensor to tensor + return %slice : tensor +} + +// CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 * 2 + s0)> +// CHECK: func.func @extract_slice_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[lb:.+]]: index, %[[sz:.+]]: index) +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz]], 5] : tensor +// CHECK-DAG: %[[d1:.+]] = tensor.dim %arg0, %[[c1]] : 
tensor<3x?x?x11xf32> +// CHECK-DAG: %[[d2:.+]] = tensor.dim %arg0, %[[c2]] : tensor<3x?x?x11xf32> +// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[sz]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) +// CHECK: %[[inputIv:.+]] = affine.apply #[[map0]](%[[iv]])[%[[lb]]] +// CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index %[[inputIv]] into (%[[c3]], %[[d1]], %[[d2]]) : +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: return %[[tile]] : + +// ----- + + +func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0: index, %size0: index, %offt1: index, %size1: index) -> tensor { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x?xf32> into tensor + %slice = tensor.extract_slice %collapsed [%offt0, %offt1] [%size0, %size1] [1, 1] : tensor to tensor + return %slice : tensor +} + +// CHECK: #[[map0:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index) +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index +// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor +// CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : +// CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : +// CHECK-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], 
%[[c4]] : +// CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[c0]] to %[[sz1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]]) +// CHECK: %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[c0]] to %[[sz2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]]) +// CHECK: %[[inputIv1:.+]] = affine.apply #[[map0:.+]](%[[iv1]])[%[[lb1]]] +// CHECK: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[inputIv1]] into (%[[c3]], %[[d1]], %[[d2]]) : +// CHECK: %[[inputIv2:.+]] = affine.apply #[[map0:.+]](%[[iv2]])[%[[lb2]]] +// CHECK: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[inputIv2]] into (%[[c11]], %[[d4]]) : +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: scf.yield %[[tile2]] : +// CHECK: return %[[tile1]] : + +// FOREACH: #[[map1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// FOREACH: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index) +// FOREACH-DAG: %[[c1:.+]] = arith.constant 1 : index +// FOREACH-DAG: %[[c2:.+]] = arith.constant 2 : index +// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index +// FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index +// FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index +// FOREACH: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor +// FOREACH-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : +// FOREACH-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : +// FOREACH-DAG: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] : +// FOREACH: %[[tile1:.+]] = scf.foreach_thread (%[[tid1:.+]], %[[tid2:.+]]) in (%[[sz1]], 
%[[sz2]]) +// FOREACH-DAG: %[[iv1:.+]] = affine.apply #[[map1]](%[[tid1]])[%[[lb1]]] +// FOREACH: %[[multiIndex1:.+]]:3 = affine.delinearize_index %[[iv1]] into (%[[c3]], %[[d1]], %[[d2]]) : +// FOREACH-DAG: %[[iv2:.+]] = affine.apply #[[map1]](%[[tid2]])[%[[lb2]]] +// FOREACH: %[[multiIndex2:.+]]:2 = affine.delinearize_index %[[iv2]] into (%[[c11]], %[[d4]]) : +// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : +// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} : +// FOREACH: perform_concurrently +//FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[init]][%[[tid1]], %[[tid2]]] [1, 1] [1, 1] : + +// ----- + +// Verifies that a linearized dimension that is not sliced does not generate a loop. Note that this +// only works for static shapes. + +// CHECK: @extract_slice_non_sliced_linearized_dim(%[[arg0:.+]]: tensor<{{.*}}>, +func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32>, %offt: index, %size: index) -> tensor { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x2xf32> into tensor + %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 22] [1, 1] : tensor to tensor + // CHECK: scf.for + // CHECK-NOT: scf.for + // CHECK: %[[multiIndex:.+]]:3 = affine.delinearize_index + // CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1] + return %slice : tensor +} diff --git a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt --- a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt @@ -6,6 +6,7 @@ LINK_LIBS PUBLIC MLIRArithmeticDialect + MLIRLinalgDialect MLIRPass MLIRSCFDialect MLIRTensorDialect diff --git 
a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" @@ -28,7 +29,8 @@ TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry &registry) const override { - registry.insert(); + registry.insert(); } StringRef getArgument() const final { @@ -49,6 +51,19 @@ *this, "test-fold-constant-extract-slice", llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"), llvm::cl::init(false)}; + + Option testRewriteExtractSliceWithTiledCollapseShape{ + *this, "test-rewrite-extract-slice-from-collapse-shape", + llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape " + "with loop nest"), + llvm::cl::init(false)}; + + Option useForeach{ + *this, "use-foreach", + llvm::cl::desc( + "Use the scf.foreach_thread operation when generating loop nests for " + "the extract_slice of collapse_shape pattern"), + llvm::cl::init(false)}; }; } // namespace @@ -74,12 +89,115 @@ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); } +namespace { +/// Pattern to swap a `tensor.extract_slice` with its producer when the +/// producer is a `tensor.collapse_shape`. The `tensor.extract_slice` is +/// replaced by a loop nest that stitches together the desired tile by +/// iterating over the linearized slice dimensions that cannot be represented +/// as a rectangular slice of the source tensor (see above example).
+struct RewriteExtractSliceWithTiledCollapseShape + : public OpRewritePattern { + RewriteExtractSliceWithTiledCollapseShape(MLIRContext *context, + bool useScfForeach) + : OpRewritePattern(context), + useScfForeach(useScfForeach) {} + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp op, + PatternRewriter &rewriter) const override; + + bool useScfForeach; +}; +} // namespace + +LogicalResult RewriteExtractSliceWithTiledCollapseShape::matchAndRewrite( + tensor::ExtractSliceOp op, PatternRewriter &rewriter) const { + auto collapseOp = op.getSource().getDefiningOp(); + if (!collapseOp) + return rewriter.notifyMatchFailure( + op, "producer is not a tensor.collapse_shape op"); + + // Materialize the output shape values of the slice operation. + ReifiedRankedShapedTypeDims reifiedShapes; + if (failed(op.reifyResultShapes(rewriter, reifiedShapes))) + return rewriter.notifyMatchFailure(op, "failed to reify result shapes"); + + // Create the destination tensor using the above values. + Type elementType = op.getSourceType().getElementType(); + SmallVector outputShape = getAsOpFoldResult(reifiedShapes[0]); + Value dest = rewriter.create(op->getLoc(), outputShape, + elementType); + + // Calculate the parameters for the tile loop nest.
+ FailureOr params = + tensor::ExtractSliceFromCollapseShapeBuilder::create(rewriter, collapseOp, + op); + if (failed(params)) + return rewriter.notifyMatchFailure(op, + "could not calculate tiling parameters"); + + Value result; + Location loc = op.getLoc(); + if (!useScfForeach) { + const unsigned numTiledDims = params->getIterationSpaceSizes().size(); + auto zero = rewriter.create(loc, 0); + auto one = rewriter.create(loc, 1); + SmallVector lbs(numTiledDims, zero); + SmallVector steps(numTiledDims, one); + scf::LoopNest nest = scf::buildLoopNest( + rewriter, loc, lbs, params->getIterationSpaceSizes(), steps, dest, + [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs, + ValueRange iterArgs) -> scf::ValueVector { + auto [tile, insertParams] = + params->emitExtractSliceFromCollapseShapeLoopNestBody( + nestedBuilder, loc, outputIvs); + + // Insert the slice into the destination. + Value result = nestedBuilder.create( + loc, tile, iterArgs[0], insertParams); + return {result}; + }); + result = nest.getResults()[0]; + } else { + auto foreachOp = rewriter.create( + loc, /*numThreads=*/params->getIterationSpaceSizes(), + /*threadDimMapping=*/ArrayRef{}, + [&](OpBuilder &nestedBuilder, Location loc, ValueRange outputIvs) { + auto [tile, insertParams] = + params->emitExtractSliceFromCollapseShapeLoopNestBody( + nestedBuilder, loc, outputIvs); + // Insert the slice into the destination. 
+ auto term = nestedBuilder.create(loc); + nestedBuilder.setInsertionPointToStart(term.getBody()); + nestedBuilder.create(loc, tile, dest, + insertParams); + }); + result = foreachOp->getResult(0); + } + + rewriter.replaceOp(op, result); + return success(); +} + +static LogicalResult +applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp, + bool useForeach) { + RewritePatternSet patterns(rootOp->getContext()); + patterns.add(rootOp->getContext(), + useForeach); + return applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); +} + void TestTensorTransforms::runOnOperation() { Operation *rootOp = getOperation(); if (testSplitPaddingPatterns) applySplitPaddingPatterns(rootOp); if (testFoldConstantExtractSlice) applyFoldConstantExtractSlicePatterns(rootOp); + if (testRewriteExtractSliceWithTiledCollapseShape) { + if (failed( + applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach))) + return signalPassFailure(); + } } namespace mlir {