diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -29,6 +29,74 @@ FailureOr replaceExtractSliceWithTiledProducer( OpBuilder &builder, tensor::ExtractSliceOp sliceOp, OpResult producerOp); +/// Generates IR required to materialize a slice "through" a +/// `tensor.collapse_shape` by creating a loop nest and populating the slice in +/// `dest` by stitching together different slices of the source tensor. The loop +/// nest contains one loop for each sliced output dimension that maps to +/// multiple source dimensions. +/// Example: +// clang-format off +/// ``` +/// %0 = ... -> tensor<3x7x11x10xf32> +/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] +/// %2 = tensor.extract_slice %1 [0, 0] [10, 10] [1, 1] +/// ``` +/// The "materialized" slice is equivalent to +/// ``` +/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = ... +/// %3:3 = arith.delinearize_index %iv (3, 7, 11) +/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] +/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] +/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] +/// } +/// ``` +// clang-format on +/// This function directly creates the materialized slice from offsets and +/// sizes. If `useForeach` is true, then an `scf.foreach_thread` operation will +/// be used instead of an `scf.for` loop nest. +FailureOr materializeSliceFromCollapseShape( + RewriterBase &builder, tensor::CollapseShapeOp op, Value dest, + ArrayRef offsets, ArrayRef sizes, + ArrayRef strides, bool useForeach); + +/// A callback type that returns a destination tensor given the size and element +/// type. +using CreateDestTensorFn = + std::function, Type)>; + +/// Pattern to swap a `tensor.extract_slice` with its producer when the +/// producer is a `tensor.collapse_shape`. 
The `tensor.extract_slice` is +/// replaced by a loop nest that stitches together the desired tile by +/// iterating over the linearized slice dimensions that cannot be represented +/// as a rectangular slice of the source tensor (see above example). +struct SliceCollapseShapeOptions { + + bool useForeach = false; + CreateDestTensorFn createDestTensorFn = nullptr; + + SliceCollapseShapeOptions &setUseForeach(bool useScfForeach) { + useForeach = useScfForeach; + return *this; + } + + SliceCollapseShapeOptions &setCreateDestTensorFunc(CreateDestTensorFn fn) { + createDestTensorFn = std::move(fn); + return *this; + } +}; +struct RewriteExtractSliceWithTiledCollapseShape + : public OpRewritePattern { + RewriteExtractSliceWithTiledCollapseShape(MLIRContext *context, + SliceCollapseShapeOptions options) + : OpRewritePattern(context), + options(std::move(options)) {} + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp op, + PatternRewriter &rewriter) const override; + + SliceCollapseShapeOptions options; +}; + } // namespace tensor } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h --- a/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h @@ -16,6 +16,7 @@ #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/StringRef.h" @@ -373,6 +374,51 @@ } }; +/// The input parameters `offsets`, `sizes`, `strides` specify a rectangular +/// slice of the collapse_shape output. Try to find which dimensions have been +/// sliced and which dimensions are not sliced (offset = 0, size = dim, stride = +/// 1). Note that this is conservative as it cannot detect if a dynamic size +/// corresponds to the full tensor dimension or not. 
+llvm::SmallBitVector +getSlicedDimensions(ArrayRef materializedOutputShape, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides); + +/// Determine which dimensions are linearized by a `tensor.collapse_shape` op by +/// inspecting its reassociation indices. +llvm::SmallBitVector +getLinearizedDimensions(ArrayRef reassociationIndices); + +/// This function assists in taking the parameters of a strided slice operation +/// (offset, size, stride) that acts on the output of a `tensor.collapse_shape` +/// operation and returns a mapping to compute the equivalent strided slice +/// parameters (offset, size, stride) for a slice of the input to the +/// reshape operation. A mapping is returned because a strided slice of the +/// output of a collapse shape cannot in general be translated into a single +/// strided slice of the input. Instead, a function `(D_0, D_1,.. D_{n-1}) -> +/// (offsets, sizes, strides)` is returned where `n` is equal to the number of +/// output dimensions that both correspond to multiple input dimensions and are +/// sliced by the provided strided slice parameters. Each `D_i` is then a tuple +/// that must represent a valid multi-index for a corresponding linearized +/// dimension. Concretely, if the input shape is `[s0, s1]` and the output shape +/// is `[s0 * s1]`, then `D_0 = (d0, d1)`, corresponding to coordinates for the +/// two collapsed input dimensions of size `s0` and `s1`. The multi-index `D_0` +/// must then satisfy `0 <= d0 * s1 + d1 < s0 * s1`. 
+using CollapseShapeSourceExtractSliceParametersFunc = std::function /*multiIndices*/, + SmallVector & /*outputOffsets*/, + SmallVector & /*outputSizes*/, + SmallVector & /*outputStrides*/)>; +FailureOr +getCollapseShapeExtractSliceParamMapping( + const SmallVector &reassociationIndices, + const SmallVector &inputTensorShape, + const SmallVector &outputTensorShape, + const SmallVector &sliceOffsets, + const SmallVector &sliceSizes, + const SmallVector &sliceStrides); + } // namespace mlir #endif // MLIR_DIALECT_UTILS_RESHAPEOPSUTILS_H diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(MLIRTensorTransforms BufferizableOpInterfaceImpl.cpp Bufferize.cpp + ExtractSliceFromReshape.cpp SplitPadding.cpp SwapExtractSliceWithProducer.cpp diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp @@ -0,0 +1,339 @@ +//===- ExtractSliceFromReshape.cpp - Slice reshape rewrites-------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements rewrites that replace slices of reshape results with +// aggregated slices of the reshape source. 
+// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" +#include "llvm/ADT/STLExtras.h" + +using namespace mlir; +using namespace mlir::tensor; + +/// Get the size of dimension `dimIdx` of a value of RankedTensor type. +OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, Value rankedTensor, + int64_t dimIdx) { + RankedTensorType tensorType = rankedTensor.getType().cast(); + if (!tensorType.isDynamicDim(dimIdx)) { + return b.getIndexAttr(tensorType.getDimSize(dimIdx)); + } + Value idxValue = b.create(loc, dimIdx); + return b.createOrFold(loc, rankedTensor, idxValue); +} + +/// Get the dimension sizes of a value of RankedTensor type. Only dimensions at +/// indices specified in `dimensions` are returned. +static SmallVector +getShapeDimSizes(OpBuilder &b, Location loc, Value rankedTensor, + ArrayRef dimensions) { + SmallVector dimSizes; + for (auto idx : dimensions) + dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, idx)); + return dimSizes; +} + +/// Get all the dimension sizes of a value of RankedTensor type. +static SmallVector getShapeDimSizes(OpBuilder &b, Location loc, + Value rankedTensor) { + SmallVector dimSizes; + RankedTensorType tensorType = rankedTensor.getType().cast(); + for (unsigned i = 0; i < tensorType.getRank(); i++) + dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, i)); + return dimSizes; +} + +namespace { + +/// Implementation details for `materializeSliceFromCollapseShape`. 
+struct CollapseShapeSliceInternal { +public: + CollapseShapeSliceInternal() = delete; + + /// Construct the class using a `tensor.collapse_shape` and strided slice + /// parameters. + CollapseShapeSliceInternal( + CollapseShapeOp op, ArrayRef offsets, + ArrayRef sizes, ArrayRef strides, + const llvm::SmallBitVector &linearizedOutputDims, + const llvm::SmallBitVector &slicedOutputDims, + CollapseShapeSourceExtractSliceParametersFunc sliceParamMapping) + : collapseOp(op), reassociationIndices(op.getReassociationIndices()), + sliceOffsets(offsets.begin(), offsets.end()), + sliceSizes(sizes.begin(), sizes.end()), + sliceStrides(strides.begin(), strides.end()), + linearizedOutputDims(linearizedOutputDims), + slicedOutputDims(slicedOutputDims), + sliceParamMapping(std::move(sliceParamMapping)) { + + for (int64_t i = 0; i < op.getResultType().getRank(); i++) { + if (linearizedOutputDims[i] && slicedOutputDims[i]) + tiledDimensions.push_back(i); + } + } + + /// Materialize a slice from a `tensor.collapse_shape` by using `scf.for` to + /// sub-tile linearized dimensions. + FailureOr buildWithScfFor(RewriterBase &b, Location loc, Value dest); + + /// Materialize a slice from a `tensor.collapse_shape` by using + /// `scf.foreach_thread` to sub-tile linearized dimensions. + FailureOr buildWithScfForeach(RewriterBase &b, Location loc, + Value dest); + +protected: + CollapseShapeOp collapseOp; + const SmallVector reassociationIndices; + + /// References to the slice parameters acting on the CollapseShapeOp output. + /// These should persist through the lifetime of this object. + ArrayRef sliceOffsets; + ArrayRef sliceSizes; + ArrayRef sliceStrides; + + const llvm::SmallBitVector linearizedOutputDims; + const llvm::SmallBitVector slicedOutputDims; + SmallVector tiledDimensions; + + CollapseShapeSourceExtractSliceParametersFunc sliceParamMapping; + + /// Fill in the body of the generated loop nest. 
`ivs` contains one Value for + /// each tiled output dimension that corresponds to a set of linearized and + /// sliced input dimensions. The `ivs` values should be treated as the + /// position within the output dimension to delinearize. `insertOffsets` and + /// `insertSizes` are filled with the output positions where the sub-slice + /// should be inserted into the destination tensor. + Value buildBody(OpBuilder &b, Location loc, ValueRange ivs, Value insertDest, + SmallVector &insertOffsets, + SmallVector &insertSizes); +}; +} // namespace + +FailureOr CollapseShapeSliceInternal::buildWithScfFor(RewriterBase &b, + Location loc, + Value dest) { + // Create the bounds for the loop nest to be akin to "for (i = offset; i < + // offset + size * stride; i += stride)" + SmallVector lbs, ubs, steps; + AffineExpr d0, s0, s1; + bindDims(b.getContext(), d0); + bindSymbols(b.getContext(), s0, s1); + for (int64_t idx : tiledDimensions) { + lbs.push_back(sliceOffsets[idx]); + ubs.push_back(makeComposedFoldedAffineApply( + b, loc, s0 + d0 * s1, + {sliceSizes[idx], lbs.back(), sliceStrides[idx]})); + steps.push_back(sliceStrides[idx]); + } + scf::LoopNest nest = scf::buildLoopNest( + b, loc, getValueOrCreateConstantIndexOp(b, loc, lbs), + getValueOrCreateConstantIndexOp(b, loc, ubs), + getValueOrCreateConstantIndexOp(b, loc, steps), dest, + [this](OpBuilder &nestedBuilder, Location loc, ValueRange ivs, + ValueRange iterArgs) -> scf::ValueVector { + // Fill in the body. + SmallVector insertOffsets, insertSizes; + Value slice = buildBody(nestedBuilder, loc, ivs, iterArgs[0], + insertOffsets, insertSizes); + + // Insert the sub-tile into the iteration arg. 
+ Value result = nestedBuilder.create( + loc, slice, iterArgs[0], insertOffsets, insertSizes, + /*strides=*/ + SmallVector(sliceOffsets.size(), + nestedBuilder.getIndexAttr(1))); + return {result}; + }); + return nest.getResults()[0]; +} + +FailureOr +CollapseShapeSliceInternal::buildWithScfForeach(RewriterBase &b, Location loc, + Value dest) { + // For each sliced and linearized dimension, set the number of threads to be + // the size of the slice in that dim. + SmallVector sizes; + for (int64_t idx : tiledDimensions) + sizes.push_back(sliceSizes[idx]); + + auto foreachOp = b.create( + loc, /*numThreads=*/getValueOrCreateConstantIndexOp(b, loc, sizes), + /*threadDimMapping=*/ArrayRef{}, + [this, dest](OpBuilder &nestedBuilder, Location loc, + ValueRange threadIds) { + // Within the body of the foreach op, create the adjusted induction + // variables to map the thread index iv to a position within the slice + // of the linearized dimension. + AffineExpr d0, s0, s1; + bindDims(nestedBuilder.getContext(), d0); + bindSymbols(nestedBuilder.getContext(), s0, s1); + SmallVector ivs; + for (const auto &it : llvm::enumerate(tiledDimensions)) { + ivs.push_back(makeComposedAffineApply( + nestedBuilder, loc, s0 + d0 * s1, + {threadIds[it.index()], + getValueOrCreateConstantIndexOp(nestedBuilder, loc, + sliceOffsets[it.value()]), + getValueOrCreateConstantIndexOp(nestedBuilder, loc, + sliceStrides[it.value()])})); + } + + // Populate the rest of the body. + SmallVector insertOffsets, insertSizes; + Value slice = buildBody(nestedBuilder, loc, ivs, dest, insertOffsets, + insertSizes); + + // Create the terminator and insert. + auto term = nestedBuilder.create(loc); + nestedBuilder.setInsertionPointToStart(term.getBody()); + nestedBuilder.create( + loc, slice, dest, insertOffsets, insertSizes, + /*strides=*/ + SmallVector(insertOffsets.size(), + nestedBuilder.getIndexAttr(1))); + }); + + return foreachOp->getResult(0); +} + +/// Fill in the body of the generated loop nest. 
`ivs` contains one Value for +/// each tiled output dimension that corresponds to a set of linearized and +/// sliced input dimensions. The `ivs` value then should be treated as the +/// position within the output tensor dimension to gather along. +Value CollapseShapeSliceInternal::buildBody( + OpBuilder &b, Location loc, ValueRange ivs, Value insertDest, + SmallVector &insertOffsets, + SmallVector &insertSizes) { + // Create the de-linearized multi indices at the start of + // each loop body. + SmallVector multiIndices; + for (const auto &it : llvm::enumerate(tiledDimensions)) { + SmallVector basis = + getShapeDimSizes(b, collapseOp->getLoc(), collapseOp.src(), + reassociationIndices[it.value()]); + auto delinearizeOp = + b.create(loc, ivs[it.index()], basis); + multiIndices.push_back(delinearizeOp->getResults()); + } + + // Extract a sub-slice from the source using multi-indices. + SmallVector extractOffsets, extractSizes, extractStrides; + sliceParamMapping(b.getContext(), multiIndices, extractOffsets, extractSizes, + extractStrides); + Value subTileResult = b.create( + loc, collapseOp.src(), extractOffsets, extractSizes, extractStrides); + + // Collapse the dimensions of the source slice back down. + Value collapsedResult = b.create( + loc, subTileResult, reassociationIndices); + + // Calculate the position where the sub-slice should be inserted into the + // destination tensor. + int64_t loopIdx = 0; + for (int64_t outDimIdx = 0, e = collapseOp.getResultType().getRank(); + outDimIdx < e; outDimIdx++) { + // Case 1: Linearized dimensions that have been sliced. + // The insert size is 1, and the offset is the iv. + if (linearizedOutputDims[outDimIdx] && slicedOutputDims[outDimIdx]) { + insertOffsets.push_back(ivs[loopIdx++]); + insertSizes.push_back(b.getIndexAttr(1)); + continue; + } + // Case 2: Otherwise, the insert is the full shape of + // the iteration argument dimension, because this output + // dimension is not being iterated over in the loop + // nest. 
+ insertOffsets.push_back(b.getIndexAttr(0)); + RankedTensorType iterArgsType = + insertDest.getType().cast(); + if (iterArgsType.isDynamicDim(outDimIdx)) + insertSizes.push_back(b.createOrFold( + loc, insertDest, + b.createOrFold(loc, outDimIdx))); + else + insertSizes.push_back(b.getIndexAttr(iterArgsType.getDimSize(outDimIdx))); + } + return collapsedResult; +} + +FailureOr mlir::tensor::materializeSliceFromCollapseShape( + RewriterBase &b, tensor::CollapseShapeOp op, Value dest, + ArrayRef offsets, ArrayRef sizes, + ArrayRef strides, bool useForeach) { + // Materialize the output shape of the collapse_shape operation. This will + // create IR describing the output shape in terms of the input shape. + ReifiedRankedShapedTypeDims reifiedShapes; + ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = + dyn_cast(op.getOperation()); + if (failed(reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes))) + return failure(); + SmallVector outputShape = getAsOpFoldResult(reifiedShapes[0]); + SmallVector reassociationIndices = + op.getReassociationIndices(); + + // Get the parameterized function that maps the given slice parameters to + // slice parameters acting on the source tensor. + FailureOr paramMapping = + getCollapseShapeExtractSliceParamMapping( + reassociationIndices, getShapeDimSizes(b, op->getLoc(), op.getSrc()), + outputShape, llvm::to_vector(offsets), llvm::to_vector(sizes), + llvm::to_vector(strides)); + if (failed(paramMapping)) + return failure(); + + CollapseShapeSliceInternal sliceBuilder( + op, offsets, sizes, strides, + getLinearizedDimensions(reassociationIndices), + getSlicedDimensions(outputShape, offsets, sizes, strides), *paramMapping); + + return useForeach ? 
sliceBuilder.buildWithScfForeach(b, op->getLoc(), dest) + : sliceBuilder.buildWithScfFor(b, op.getLoc(), dest); +} + +LogicalResult RewriteExtractSliceWithTiledCollapseShape::matchAndRewrite( + tensor::ExtractSliceOp op, PatternRewriter &rewriter) const { + auto collapseOp = op.getSource().getDefiningOp(); + if (!collapseOp) + return rewriter.notifyMatchFailure( + op, "producer is not a tensor.collapse_shape op"); + + // Materialize the output shape values. + ReifiedRankedShapedTypeDims reifiedShapes; + if (failed(op.reifyResultShapes(rewriter, reifiedShapes))) + return rewriter.notifyMatchFailure(op, "failed to reify result shapes"); + + // Create the destination tensor using the above values. + Value dest; + Type elementType = op.getSourceType().getElementType(); + SmallVector outputShape = getAsOpFoldResult(reifiedShapes[0]); + if (options.createDestTensorFn) + dest = options.createDestTensorFn(rewriter, op->getLoc(), outputShape, + elementType); + else + dest = rewriter.create(op->getLoc(), outputShape, + elementType); + + FailureOr result = materializeSliceFromCollapseShape( + rewriter, collapseOp, dest, op.getMixedOffsets(), op.getMixedSizes(), + op.getMixedStrides(), options.useForeach); + if (failed(result)) + return rewriter.notifyMatchFailure( + op, "failed to extract slice from tensor.collapse_shape"); + rewriter.replaceOp(op, *result); + return success(); +} diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Utils/ReshapeOpsUtils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Builders.h" @@ -270,3 +271,88 @@ return !memrefType.getLayout().isIdentity(); return false; } + +llvm::SmallBitVector +mlir::getSlicedDimensions(ArrayRef materializedOutputShape, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + 
llvm::SmallBitVector mask(materializedOutputShape.size()); + for (const auto &it : llvm::enumerate(llvm::zip(offsets, sizes, strides))) { + Optional stride = getConstantIntValue(std::get<2>(it.value())); + Optional offset = getConstantIntValue(std::get<0>(it.value())); + OpFoldResult size = std::get<1>(it.value()); + mask[it.index()] = + !isEqualConstantIntOrValue(size, materializedOutputShape[it.index()]) || + (!stride || *stride != 1) || (!offset || *offset != 0); + } + return mask; +} + +llvm::SmallBitVector mlir::getLinearizedDimensions( + ArrayRef reassociationIndices) { + llvm::SmallBitVector result(reassociationIndices.size()); + for (const auto &it : llvm::enumerate(reassociationIndices)) + result[it.index()] = it.value().size() > 1; + return result; +} + +FailureOr +mlir::getCollapseShapeExtractSliceParamMapping( + const SmallVector &reassociationIndices, + const SmallVector &inputTensorShape, + const SmallVector &outputTensorShape, + const SmallVector &sliceOffsets, + const SmallVector &sliceSizes, + const SmallVector &sliceStrides) { + + llvm::SmallBitVector linearizedOutputDims = + getLinearizedDimensions(reassociationIndices); + llvm::SmallBitVector slicedOutputDims = getSlicedDimensions( + outputTensorShape, sliceOffsets, sliceSizes, sliceStrides); + + CollapseShapeSourceExtractSliceParametersFunc result = + [=](MLIRContext *ctx, ArrayRef multiIndices, + SmallVector &offsets, SmallVector &sizes, + SmallVector &strides) -> void { + int64_t loopIdx = 0; + auto oneAttr = IntegerAttr::get(IndexType::get(ctx), 1); + auto zeroAttr = IntegerAttr::get(IndexType::get(ctx), 0); + offsets.reserve(reassociationIndices.size()); + sizes.reserve(reassociationIndices.size()); + strides.reserve(reassociationIndices.size()); + for (const auto &it : llvm::enumerate(reassociationIndices)) { + // Case 1: De-linearized dimensions that have also been sliced. These have + // a size of 1 because we are iterating over these dimensions. 
The offset + // is exactly the de-linearized multi index created from the iv's. + if (slicedOutputDims[it.index()] && linearizedOutputDims[it.index()]) { + sizes.append(it.value().size(), oneAttr); + llvm::append_range(offsets, + llvm::map_range(multiIndices[loopIdx++], + [&](Value v) -> OpFoldResult { + return getAsOpFoldResult(v); + })); + strides.append(it.value().size(), oneAttr); + continue; + } + + // Case 2: One or possibly multiple combined input dimensions, but we + // have proven that these are not sliced. In this case we just take the + // full extent of each index in the list. + if (linearizedOutputDims[it.index()]) { + for (auto idx : it.value()) + sizes.push_back(inputTensorShape[idx]); + offsets.append(it.value().size(), zeroAttr); + strides.append(it.value().size(), oneAttr); + continue; + } + + // Case 3: A single index, but it may be sliced. + sizes.push_back(sliceSizes[it.index()]); + offsets.push_back(sliceOffsets[it.index()]); + strides.push_back(sliceStrides[it.index()]); + } + }; + + return result; +} diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir @@ -0,0 +1,160 @@ +// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-rewrite-extract-slice-from-collapse-shape %s | FileCheck %s +// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns="test-rewrite-extract-slice-from-collapse-shape use-foreach" %s | FileCheck %s --check-prefix=FOREACH + +func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf32> { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32> + %slice = tensor.extract_slice %collapsed [0, 0] [20, 11] [1, 1] : tensor<105x11xf32> to tensor<20x11xf32> + return %slice : tensor<20x11xf32> +} + +// CHECK: func.func 
@extract_slice_static(%[[arg0:.+]]: +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index +// CHECK-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] : +// CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) +// CHECK: %[[multiIndex:.+]]:3 = arith.delinearize_index %[[iv]](%[[c3]], %[[c5]], %[[c7]] +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: return %[[tile]] + +// FOREACH: func.func @extract_slice_static(%[[arg0:.+]]: +// FOREACH-DAG: %[[c20:.+]] = arith.constant 20 : index +// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index +// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index +// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index +// FOREACH-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] : +// FOREACH-DAG: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (%[[c20]]) +// FOREACH: %[[multiIndex:.+]]:3 = arith.delinearize_index %[[iv]](%[[c3]], %[[c5]], %[[c7]] +// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : +// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// FOREACH: perform_concurrently +// FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[init]][%[[iv]], 0] [1, 11] [1, 1] : +// FOREACH: return %[[tile]] + +// ----- + + 
+func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor<10x5xf32> { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32> + %slice = tensor.extract_slice %collapsed [13, 0] [10, 5] [2, 2] : tensor<105x11xf32> to tensor<10x5xf32> + return %slice : tensor<10x5xf32> +} + +// CHECK: func.func @extract_slice_static_strided(%[[arg0:.+]]: +// CHECK-DAG: %[[c13:.+]] = arith.constant 13 : index +// CHECK-DAG: %[[c33:.+]] = arith.constant 33 : index +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index +// CHECK: %[[init:.+]] = linalg.init_tensor [10, 5] : +// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c13]] to %[[c33]] step %[[c2]] iter_args(%[[iterArg:.+]] = %[[init]]) +// CHECK: %[[multiIndex:.+]]:3 = arith.delinearize_index %[[iv]](%[[c3]], %[[c5]], %[[c7]] +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: return %[[tile]] + + +// ----- + + +func.func @extract_slice_dynamic(%input: tensor<3x?x?x11xf32>, %offt: index, %size: index) -> tensor { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x?x?x11xf32> into tensor + %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 5] [2, 2] : tensor to tensor + return %slice : tensor +} + +// CHECK: #[[map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)> +// CHECK: func.func @extract_slice_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[lb:.+]]: index, %[[sz:.+]]: index) +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: 
%[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz]], 5] : tensor +// CHECK: %[[ub:.+]] = affine.apply #[[map1]]()[%[[lb]], %[[sz]]] +// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[lb]] to %[[ub]] step %[[c2]] iter_args(%[[iterArg:.+]] = %[[init]]) +// CHECK-DAG: %[[d1:.+]] = tensor.dim %arg0, %[[c1]] : tensor<3x?x?x11xf32> +// CHECK-DAG: %[[d2:.+]] = tensor.dim %arg0, %[[c2]] : tensor<3x?x?x11xf32> +// CHECK: %[[multiIndex:.+]]:3 = arith.delinearize_index %arg3(%[[c3]], %[[d1]], %[[d2]]) : +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: return %[[tile]] : + +// ----- + + +func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0: index, %size0: index, %offt1: index, %size1: index) -> tensor { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x?xf32> into tensor + %slice = tensor.extract_slice %collapsed [%offt0, %offt1] [%size0, %size1] [1, 1] : tensor to tensor + return %slice : tensor +} + +// CHECK: #[[map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index) +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index +// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor +// CHECK: %[[ub1:.+]] = 
affine.apply #[[map1:.+]]()[%[[lb1]], %[[sz1]]] +// CHECK: %[[ub2:.+]] = affine.apply #[[map1:.+]]()[%[[lb2]], %[[sz2]]] +// CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[lb1]] to %[[ub1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]]) +// CHECK: %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[lb2]] to %[[ub2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]]) +// CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : +// CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : +// CHECK: %[[multiIndex1:.+]]:3 = arith.delinearize_index %[[iv1]](%[[c3]], %[[d1]], %[[d2]]) : +// CHECK: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] : +// CHECK: %[[multiIndex2:.+]]:2 = arith.delinearize_index %[[iv2]](%[[c11]], %[[d4]]) : +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: scf.yield %[[tile2]] : +// CHECK: return %[[tile1]] : + +// FOREACH: #[[map1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// FOREACH: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index) +// FOREACH-DAG: %[[c1:.+]] = arith.constant 1 : index +// FOREACH-DAG: %[[c2:.+]] = arith.constant 2 : index +// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index +// FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index +// FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index +// FOREACH: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor +// FOREACH: %[[tile1:.+]] = scf.foreach_thread (%[[tid1:.+]], %[[tid2:.+]]) in (%[[sz1]], %[[sz2]]) +// FOREACH-DAG: %[[iv1:.+]] = affine.apply 
#[[map1]](%[[tid1]])[%[[lb1]]] +// FOREACH-DAG: %[[iv2:.+]] = affine.apply #[[map1]](%[[tid2]])[%[[lb2]]] +// FOREACH-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : +// FOREACH-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : +// FOREACH: %[[multiIndex1:.+]]:3 = arith.delinearize_index %[[iv1]](%[[c3]], %[[d1]], %[[d2]]) : +// FOREACH: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] : +// FOREACH: %[[multiIndex2:.+]]:2 = arith.delinearize_index %[[iv2]](%[[c11]], %[[d4]]) : +// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : +// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} : +// FOREACH: perform_concurrently +//FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[init]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] : + +// ----- + +// Verifies that a linearized dimension that is not sliced does not generate a loop. Note that this +// only works for static shapes. 
+ +// CHECK: @extract_slice_non_sliced_linearized_dim(%[[arg0:.+]]: tensor<{{.*}}>, +func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32>, %offt: index, %size: index) -> tensor { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x2xf32> into tensor + %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 22] [1, 1] : tensor to tensor + // CHECK: scf.for + // CHECK-NOT: scf.for + // CHECK: %[[multiIndex:.+]]:3 = arith.delinearize_index + // CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1] + return %slice : tensor +} diff --git a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt --- a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt @@ -6,6 +6,7 @@ LINK_LIBS PUBLIC MLIRArithmeticDialect + MLIRLinalgDialect MLIRPass MLIRSCFDialect MLIRTensorDialect diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" @@ -28,7 +29,8 @@ TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } StringRef getArgument() const final { @@ -49,6 +51,19 @@ *this, "test-fold-constant-extract-slice", llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"), llvm::cl::init(false)}; + + Option 
testRewriteExtractSliceWithTiledCollapseShape{ + *this, "test-rewrite-extract-slice-from-collapse-shape", + llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape " + "with loop nest"), + llvm::cl::init(false)}; + + Option useForeach{ + *this, "use-foreach", + llvm::cl::desc( + "Use the scf.foreach_thread operation when generating loop nests for " + "the extract_slice of collapse_shape pattern"), + llvm::cl::init(false)}; }; } // namespace @@ -74,12 +89,23 @@ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); } +static void applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp, + bool useForeach) { + RewritePatternSet patterns(rootOp->getContext()); + auto options = tensor::SliceCollapseShapeOptions().setUseForeach(useForeach); + patterns.add( + rootOp->getContext(), options); + (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); +} + void TestTensorTransforms::runOnOperation() { Operation *rootOp = getOperation(); if (testSplitPaddingPatterns) applySplitPaddingPatterns(rootOp); if (testFoldConstantExtractSlice) applyFoldConstantExtractSlicePatterns(rootOp); + if (testRewriteExtractSliceWithTiledCollapseShape) + applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach); } namespace mlir {