diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -32,6 +32,74 @@ /// Populates rewrite patterns that lower `tensor.delinearize_index`. void populateLowerDelinearizeIndexPatterns(RewritePatternSet &patterns); +/// Generates IR required to materialize a slice "through" a +/// `tensor.collapse_shape` by creating a loop nest and populating the slice in +/// `dest` by stitching together different slices of the source tensor. The loop +/// nest contains one loop for each sliced output dimension that maps to +/// multiple source dimensions. +/// Example: +// clang-format off +/// ``` +/// %0 = ... -> tensor<3x7x11x10xf32> +/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]] +/// %2 = tensor.extract_slice %1 [0, 0] [10, 10] [1, 1] +/// ``` +/// The "materialized" slice is equivalent to +/// ``` +/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = ... +/// %3:3 = tensor.delinearize_index %iv (3, 7, 11) +/// %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] +/// %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] +/// %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] +/// } +/// ``` +// clang-format on +/// This function directly creates the materialized slice from offsets and +/// sizes. If `useForeach` is true, then an `scf.foreach_thread` operation will +/// be used instead of an `scf.for` loop nest. +FailureOr materializeSliceFromCollapseShape( + OpBuilder &builder, tensor::CollapseShapeOp op, Value dest, + ArrayRef offsets, ArrayRef sizes, + ArrayRef strides, bool useForeach); + +/// A callback type that returns a destination tensor given the size and element +/// type. 
+using CreateDestTensorFn = + std::function, Type)>; + +/// Pattern to swap a `tensor.extract_slice` with its producer when the +/// producer is a `tensor.collapse_shape`. The `tensor.extract_slice` is +/// replaced by a loop nest that stitches together the desired tile by +/// iterating over the linearized slice dimensions that cannot be represented +/// as a rectangular slice of the source tensor (see above example). +struct SliceCollapseShapeOptions { + + bool useForeach = false; + CreateDestTensorFn createDestTensorFn = nullptr; + + SliceCollapseShapeOptions &setUseForeach(bool useScfForeach) { + useForeach = useScfForeach; + return *this; + } + + SliceCollapseShapeOptions &setCreateDestTensorFunc(CreateDestTensorFn fn) { + createDestTensorFn = std::move(fn); + return *this; + } +}; +struct RewriteExtractSliceWithTiledCollapseShape + : public OpRewritePattern { + RewriteExtractSliceWithTiledCollapseShape(MLIRContext *context, + SliceCollapseShapeOptions options) + : OpRewritePattern(context), + options(std::move(options)) {} + + LogicalResult matchAndRewrite(tensor::ExtractSliceOp op, + PatternRewriter &rewriter) const override; + + SliceCollapseShapeOptions options; +}; + } // namespace tensor } // namespace mlir diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ BufferizableOpInterfaceImpl.cpp Bufferize.cpp DelinearizeIndex.cpp + ExtractSliceFromReshape.cpp SplitPadding.cpp SwapExtractSliceWithProducer.cpp diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp @@ -0,0 +1,370 @@ +//===- ExtractSliceFromReshape.cpp - Slice reshape rewrites-------*- C++-*-===// +// +// Part of the LLVM 
Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements rewrites that replace slices of reshape results with +// aggregated slices of the reshape source. +// +//===----------------------------------------------------------------------===// +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Arithmetic/Utils/Utils.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/OpDefinition.h" + +using namespace mlir; +using namespace mlir::tensor; + +/// The input parameters `offsets`, `sizes`, `strides` specify a rectangular +/// slice of the collapse_shape output. Try to find which dimensions have been +/// sliced and which dimensions are not sliced (offset = 0, size = dim, stride = +/// 1). Note that this is conservative as it cannot detect if a dynamic size +/// corresponds to the full tensor dimension or not. +static llvm::SmallBitVector getSlicedDimensions( + OpBuilder &b, CollapseShapeOp op, ArrayRef offsets, + ArrayRef sizes, ArrayRef strides) { + // Materialize the output shape values. 
+ ReifiedRankedShapedTypeDims reifiedShapes; + ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = + dyn_cast(op.getOperation()); + (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes); + llvm::SmallBitVector result(op.getResultType().getRank()); + for (const auto &it : llvm::enumerate(llvm::zip(offsets, sizes, strides))) { + auto size = std::get<1>(it.value()); + Optional stride = getConstantIntValue(std::get<2>(it.value())); + Optional offset = getConstantIntValue(std::get<0>(it.value())); + result[it.index()] = + !isEqualConstantIntOrValue(size, reifiedShapes[0][it.index()]) || + (!stride || *stride != 1) || (!offset || *offset != 0); + } + return result; +} + +/// Determine which dimensions are linearized by a collapse shape op. +static llvm::SmallBitVector getLinearizedDimensions(CollapseShapeOp op) { + llvm::SmallBitVector result(op.getResultType().getRank()); + for (const auto &it : llvm::enumerate(op.getReassociationIndices())) { + result[it.index()] = it.value().size() > 1; + } + return result; +} + +/// Get the dimension sizes of a value of RankedTensor type. Only dimensions at +/// indices specified in `dimensions` are returned. +static SmallVector +getShapeDimSizes(OpBuilder &b, Location loc, Value rankedTensor, + ArrayRef dimensions) { + SmallVector basis; + for (auto idx : dimensions) { + Value idxValue = b.create(loc, idx); + basis.push_back(b.createOrFold(loc, rankedTensor, idxValue)); + } + return basis; +} + +namespace { +struct CollapseShapeSliceInternal { +protected: + CollapseShapeOp collapseOp; + const SmallVector reassociationIndices; + + const SmallVector sliceOffsets; + const SmallVector sliceSizes; + const SmallVector sliceStrides; + + const llvm::SmallBitVector linearizedOutputDims; + const llvm::SmallBitVector slicedOutputDims; + SmallVector tiledDimensions; + +public: + /// Construct the class using a `tensor.collapse_shape` and slice offsets, + /// sizes, and strides. 
+ CollapseShapeSliceInternal(OpBuilder &b, CollapseShapeOp op, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) + : collapseOp(op), reassociationIndices(op.getReassociationIndices()), + sliceOffsets(offsets.begin(), offsets.end()), + sliceSizes(sizes.begin(), sizes.end()), + sliceStrides(strides.begin(), strides.end()), + linearizedOutputDims(getLinearizedDimensions(collapseOp)), + slicedOutputDims(getSlicedDimensions(b, op, sliceOffsets, sliceSizes, + sliceStrides)) { + + for (int64_t i = 0; i < op.getResultType().getRank(); i++) { + if (linearizedOutputDims[i] && slicedOutputDims[i]) + tiledDimensions.push_back(i); + } + } + + /// Materialize a slice from a `tensor.collapse_shape` by using `scf.for` to + /// sub-tile linearized dimensions. + FailureOr buildWithScfFor(OpBuilder &b, Location loc, Value dest); + + /// Materialize a slice from a `tensor.collapse_shape` by using + /// `scf.foreach_thread` to sub-tile linearized dimensions. + FailureOr buildWithScfForeach(OpBuilder &b, Location loc, Value dest); + +protected: + /// Fill in the body of the generated loop nest. `ivs` contains one Value for + /// each tiled output dimension that corresponds to a set of linearized and + /// sliced input dimensions. The `ivs` values should be treated as the + /// position within the output dimension to delinearize. `insertOffsets` and + /// `insertSizes` are filled with the output positions where the sub-slice + /// should be inserted into the destination tensor. + Value buildBody(OpBuilder &b, Location loc, ValueRange ivs, Value insertDest, + SmallVector &insertOffsets, + SmallVector &insertSizes); + + /// Use the given `multiIndices` for the collapsed and tiled dims to + /// extract a subset (sub-tile) of the slice from the source of the + /// `tensor.collapse_shape`. The subset extracted is equivalent to a tile from + /// the source where linearized and sliced dimensions have been tiled by 1. 
+ Value extractSubTileFromSource(OpBuilder &b, Location loc, + ArrayRef multiIndices); +}; +} // namespace + +FailureOr CollapseShapeSliceInternal::buildWithScfFor(OpBuilder &b, + Location loc, + Value dest) { + // Create the bounds for the loop nest to be akin + // to "for(i = offset; i < offset + size * stride; i += stride)" + SmallVector lbs; + SmallVector ubs; + SmallVector steps; + AffineExpr s0, s1, s2; + bindSymbols(b.getContext(), s0, s1, s2); + for (int64_t idx : tiledDimensions) { + Value strideVal = + getValueOrCreateConstantIndexOp(b, loc, sliceStrides[idx]); + lbs.push_back(getValueOrCreateConstantIndexOp(b, loc, sliceOffsets[idx])); + ubs.push_back(makeComposedAffineApply( + b, loc, s0 + s1 * s2, + {lbs.back(), getValueOrCreateConstantIndexOp(b, loc, sliceSizes[idx]), + strideVal})); + steps.push_back(strideVal); + } + scf::LoopNest nest = scf::buildLoopNest( + b, loc, lbs, ubs, steps, dest, + [this](OpBuilder &nestedBuilder, Location loc, ValueRange ivs, + ValueRange iterArgs) -> scf::ValueVector { + // Fill in the body. + SmallVector insertOffsets, insertSizes; + Value slice = buildBody(nestedBuilder, loc, ivs, iterArgs[0], + insertOffsets, insertSizes); + + // Insert the sub-tile into the iteration arg. + Value result = nestedBuilder.create( + loc, slice, iterArgs[0], insertOffsets, insertSizes, + /*strides=*/ + SmallVector(sliceOffsets.size(), + nestedBuilder.getIndexAttr(1))); + return {result}; + }); + return nest.getResults()[0]; +} + +FailureOr CollapseShapeSliceInternal::buildWithScfForeach(OpBuilder &b, + Location loc, + Value dest) { + // For each sliced and linearized dimension, set the number of threads to be + // the size of the slice in that dim. 
+ SmallVector sizes; + for (int64_t idx : tiledDimensions) + sizes.push_back(sliceSizes[idx]); + + auto foreachOp = b.create( + loc, /*numThreads=*/getValueOrCreateConstantIndexOp(b, loc, sizes), + /*threadDimMapping=*/ArrayRef{}, + [this, dest](OpBuilder &nestedBuilder, Location loc, + ValueRange threadIds) { + // Within the body of the foreach op, create the adjusted induction + // variables to map the thread index iv to a position within the slice + // of the linearized dimension. + AffineExpr d0, s0, s1; + bindDims(nestedBuilder.getContext(), d0); + bindSymbols(nestedBuilder.getContext(), s0, s1); + SmallVector ivs; + for (const auto &it : llvm::enumerate(tiledDimensions)) { + ivs.push_back(makeComposedAffineApply( + nestedBuilder, loc, s0 + d0 * s1, + {threadIds[it.index()], + getValueOrCreateConstantIndexOp(nestedBuilder, loc, + sliceOffsets[it.value()]), + getValueOrCreateConstantIndexOp(nestedBuilder, loc, + sliceStrides[it.value()])})); + } + + // Populate the rest of the body. + SmallVector insertOffsets, insertSizes; + Value slice = buildBody(nestedBuilder, loc, ivs, dest, insertOffsets, + insertSizes); + + // Create the terminator and insert. + auto term = nestedBuilder.create(loc); + nestedBuilder.setInsertionPointToStart(term.getBody()); + nestedBuilder.create( + loc, slice, dest, insertOffsets, insertSizes, + /*strides=*/ + SmallVector(insertOffsets.size(), + nestedBuilder.getIndexAttr(1))); + }); + + return foreachOp->getResult(0); +} + +Value CollapseShapeSliceInternal::extractSubTileFromSource( + OpBuilder &b, Location loc, ArrayRef multiIndices) { + int64_t loopIdx = 0; + RankedTensorType srcType = collapseOp.getSrcType(); + SmallVector extractOffsets; + SmallVector extractSizes; + SmallVector extractStrides; + for (const auto &it : llvm::enumerate(reassociationIndices)) { + // Case 1: De-linearized dimensions that have also been sliced. These have a + // size of 1 because we are iterating over these dimensions. 
The offset + // is exactly the de-linearized multi index created from the iv's. + if (slicedOutputDims[it.index()] && linearizedOutputDims[it.index()]) { + extractSizes.append(it.value().size(), b.getIndexAttr(1)); + extractOffsets.append(llvm::to_vector(llvm::map_range( + multiIndices[loopIdx++], + [&](Value v) -> OpFoldResult { return getAsOpFoldResult(v); }))); + extractStrides.append(it.value().size(), b.getIndexAttr(1)); + continue; + } + + // Case 2: One or possibly multiple combined input dimensions, but we + // have proven that these are not sliced. In this case we just take the + // full extent of each index in the list. + if (linearizedOutputDims[it.index()]) { + for (int64_t srcIndex : it.value()) { + if (!srcType.isDynamicDim(srcIndex)) + extractSizes.push_back(b.getIndexAttr(srcType.getDimSize(srcIndex))); + else + extractSizes.push_back(b.createOrFold( + loc, collapseOp.src(), + b.create(loc, srcIndex))); + extractOffsets.push_back(b.getIndexAttr(0)); + extractStrides.push_back(b.getIndexAttr(1)); + } + continue; + } + + // Case 3: A single index, but it may be sliced. + extractSizes.push_back(sliceSizes[it.index()]); + extractOffsets.push_back(sliceOffsets[it.index()]); + extractStrides.push_back(sliceStrides[it.index()]); + } + return b.create(loc, collapseOp.src(), extractOffsets, + extractSizes, extractStrides); +} + +/// Fill in the body of the generated loop nest. `ivs` contains one Value for +/// each tiled output dimension that corresponds to a set of linearized and +/// sliced input dimensions. The `ivs` value then should be treated as the +/// position within the output tensor dimension to gather along. +Value CollapseShapeSliceInternal::buildBody( + OpBuilder &b, Location loc, ValueRange ivs, Value insertDest, + SmallVector &insertOffsets, + SmallVector &insertSizes) { + // Create the de-linearized multi indices at the start of + // each loop body. 
+ SmallVector multiIndices; + for (const auto &it : llvm::enumerate(tiledDimensions)) { + SmallVector basis = + getShapeDimSizes(b, collapseOp->getLoc(), collapseOp.src(), + reassociationIndices[it.value()]); + auto delinearizeOp = + b.create(loc, ivs[it.index()], basis); + multiIndices.push_back(delinearizeOp->getResults()); + } + + // Extract a sub-slice from the source using multi-indices. + Value subTileResult = extractSubTileFromSource(b, loc, multiIndices); + + // Collapse the dimensions of the source slice back down. + Value collapsedResult = b.create( + loc, subTileResult, reassociationIndices); + + // Calculate the position where the sub-slice should be inserted into the + // destination tensor. + int64_t loopIdx = 0; + for (int64_t outDimIdx = 0, e = collapseOp.getResultType().getRank(); + outDimIdx < e; outDimIdx++) { + // Case 1: Linearized dimensions that have been sliced. + // The insert size is 1, and the offset is the iv. NOTE(review): the iv is a position in the sliced source dim (it starts at the slice offset and steps by the slice stride), not in the destination; confirm this offset is in bounds when offset != 0 or stride != 1. + if (linearizedOutputDims[outDimIdx] && slicedOutputDims[outDimIdx]) { + insertOffsets.push_back(ivs[loopIdx++]); + insertSizes.push_back(b.getIndexAttr(1)); + continue; + } + // Case 2: Otherwise, the insert is the full shape of + // the iteration argument dimension, because this output + // dimension is not being iterated over in the loop + // nest. + insertOffsets.push_back(b.getIndexAttr(0)); + RankedTensorType iterArgsType = + insertDest.getType().cast(); + if (iterArgsType.isDynamicDim(outDimIdx)) + insertSizes.push_back(b.createOrFold( + loc, insertDest, + b.createOrFold(loc, outDimIdx))); + else + insertSizes.push_back(b.getIndexAttr(iterArgsType.getDimSize(outDimIdx))); + } + return collapsedResult; +} + +FailureOr mlir::tensor::materializeSliceFromCollapseShape( + OpBuilder &b, tensor::CollapseShapeOp op, Value dest, + ArrayRef offsets, ArrayRef sizes, + ArrayRef strides, bool useForeach) { + CollapseShapeSliceInternal sliceBuilder(b, op, offsets, sizes, strides); + return useForeach ? 
sliceBuilder.buildWithScfForeach(b, op->getLoc(), dest) + : sliceBuilder.buildWithScfFor(b, op.getLoc(), dest); +} + +LogicalResult RewriteExtractSliceWithTiledCollapseShape::matchAndRewrite( + tensor::ExtractSliceOp op, PatternRewriter &rewriter) const { + auto collapseOp = op.getSource().getDefiningOp(); + if (!collapseOp) + return rewriter.notifyMatchFailure( + op, "producer is not a tensor.collapse_shape op"); + + // Materialize the output shape values. + ReifiedRankedShapedTypeDims reifiedShapes; + ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface = + cast(op.getOperation()); + if (failed( + reifyShapedTypeInterface.reifyResultShapes(rewriter, reifiedShapes))) + return rewriter.notifyMatchFailure(op, "failed to reify result shapes"); + + // Create the destination tensor using the above values. + Value dest; + Type elementType = op.getSourceType().getElementType(); + SmallVector outputShape = getAsOpFoldResult(reifiedShapes[0]); + if (options.createDestTensorFn) + dest = options.createDestTensorFn(rewriter, op->getLoc(), outputShape, + elementType); + else + dest = rewriter.create(op->getLoc(), outputShape, + elementType); + + FailureOr result = materializeSliceFromCollapseShape( + rewriter, collapseOp, dest, op.getMixedOffsets(), op.getMixedSizes(), + op.getMixedStrides(), options.useForeach); + if (failed(result)) + return rewriter.notifyMatchFailure( + op, "failed to extract slice from tensor.collapse_shape"); + rewriter.replaceOp(op, *result); + return success(); +} diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir @@ -0,0 +1,160 @@ +// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-rewrite-extract-slice-from-collapse-shape %s | FileCheck %s +// RUN: mlir-opt -split-input-file 
-test-tensor-transform-patterns="test-rewrite-extract-slice-from-collapse-shape use-foreach" %s | FileCheck %s --check-prefix=FOREACH + +func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf32> { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32> + %slice = tensor.extract_slice %collapsed [0, 0] [20, 11] [1, 1] : tensor<105x11xf32> to tensor<20x11xf32> + return %slice : tensor<20x11xf32> +} + +// CHECK: func.func @extract_slice_static(%[[arg0:.+]]: +// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index +// CHECK-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] : +// CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]]) +// CHECK: %[[multiIndex:.+]]:3 = tensor.delinearize_index %[[iv]](%[[c3]], %[[c5]], %[[c7]] +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: return %[[tile]] + +// FOREACH: func.func @extract_slice_static(%[[arg0:.+]]: +// FOREACH-DAG: %[[c20:.+]] = arith.constant 20 : index +// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index +// FOREACH-DAG: %[[c5:.+]] = arith.constant 5 : index +// FOREACH-DAG: %[[c7:.+]] = arith.constant 7 : index +// FOREACH-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] : +// FOREACH-DAG: %[[tile:.+]] = scf.foreach_thread (%[[iv:.+]]) in (%[[c20]]) +// FOREACH: 
%[[multiIndex:.+]]:3 = tensor.delinearize_index %[[iv]](%[[c3]], %[[c5]], %[[c7]] +// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] : +// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// FOREACH: perform_concurrently +// FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[init]][%[[iv]], 0] [1, 11] [1, 1] : +// FOREACH: return %[[tile]] + +// ----- + + +func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor<10x5xf32> { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32> + %slice = tensor.extract_slice %collapsed [13, 0] [10, 5] [2, 2] : tensor<105x11xf32> to tensor<10x5xf32> + return %slice : tensor<10x5xf32> +} + +// CHECK: func.func @extract_slice_static_strided(%[[arg0:.+]]: +// CHECK-DAG: %[[c13:.+]] = arith.constant 13 : index +// CHECK-DAG: %[[c33:.+]] = arith.constant 33 : index +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index +// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index +// CHECK: %[[init:.+]] = linalg.init_tensor [10, 5] : +// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c13]] to %[[c33]] step %[[c2]] iter_args(%[[iterArg:.+]] = %[[init]]) +// CHECK: %[[multiIndex:.+]]:3 = tensor.delinearize_index %[[iv]](%[[c3]], %[[c5]], %[[c7]] +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: return %[[tile]] + + +// ----- + + +func.func @extract_slice_dynamic(%input: 
tensor<3x?x?x11xf32>, %offt: index, %size: index) -> tensor { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x?x?x11xf32> into tensor + %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 5] [2, 2] : tensor to tensor + return %slice : tensor +} + +// CHECK: #[[map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)> +// CHECK: func.func @extract_slice_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[lb:.+]]: index, %[[sz:.+]]: index) +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz]], 5] : tensor +// CHECK: %[[ub:.+]] = affine.apply #[[map1]]()[%[[lb]], %[[sz]]] +// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[lb]] to %[[ub]] step %[[c2]] iter_args(%[[iterArg:.+]] = %[[init]]) +// CHECK-DAG: %[[d1:.+]] = tensor.dim %arg0, %[[c1]] : tensor<3x?x?x11xf32> +// CHECK-DAG: %[[d2:.+]] = tensor.dim %arg0, %[[c2]] : tensor<3x?x?x11xf32> +// CHECK: %[[multiIndex:.+]]:3 = tensor.delinearize_index %arg3(%[[c3]], %[[d1]], %[[d2]] : index, index, index) +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: return %[[tile]] : + +// ----- + + +func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0: index, %size0: index, %offt1: index, %size1: index) -> tensor { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x?xf32> into tensor + %slice = tensor.extract_slice %collapsed [%offt0, %offt1] [%size0, %size1] [1, 1] : tensor to tensor + return %slice : tensor +} + +// CHECK: #[[map1:.+]] = affine_map<()[s0, s1] -> 
(s0 + s1)> +// CHECK: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index) +// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index +// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index +// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index +// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor +// CHECK: %[[ub1:.+]] = affine.apply #[[map1:.+]]()[%[[lb1]], %[[sz1]]] +// CHECK: %[[ub2:.+]] = affine.apply #[[map1:.+]]()[%[[lb2]], %[[sz2]]] +// CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[lb1]] to %[[ub1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]]) +// CHECK: %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[lb2]] to %[[ub2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]]) +// CHECK-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : +// CHECK-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : +// CHECK: %[[multiIndex1:.+]]:3 = tensor.delinearize_index %[[iv1]](%[[c3]], %[[d1]], %[[d2]] : +// CHECK: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] : +// CHECK: %[[multiIndex2:.+]]:2 = tensor.delinearize_index %[[iv2]](%[[c11]], %[[d4]] : +// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : +// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} : +// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] : +// CHECK: scf.yield %[[update]] : +// CHECK: scf.yield %[[tile2]] : +// CHECK: return %[[tile1]] : + +// FOREACH: #[[map1:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// FOREACH: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: 
index, %[[sz2:.+]]: index) +// FOREACH-DAG: %[[c1:.+]] = arith.constant 1 : index +// FOREACH-DAG: %[[c2:.+]] = arith.constant 2 : index +// FOREACH-DAG: %[[c3:.+]] = arith.constant 3 : index +// FOREACH-DAG: %[[c4:.+]] = arith.constant 4 : index +// FOREACH-DAG: %[[c11:.+]] = arith.constant 11 : index +// FOREACH: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor +// FOREACH: %[[tile1:.+]] = scf.foreach_thread (%[[tid1:.+]], %[[tid2:.+]]) in (%[[sz1]], %[[sz2]]) +// FOREACH-DAG: %[[iv1:.+]] = affine.apply #[[map1]](%[[tid1]])[%[[lb1]]] +// FOREACH-DAG: %[[iv2:.+]] = affine.apply #[[map1]](%[[tid2]])[%[[lb2]]] +// FOREACH-DAG: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : +// FOREACH-DAG: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : +// FOREACH: %[[multiIndex1:.+]]:3 = tensor.delinearize_index %[[iv1]](%[[c3]], %[[d1]], %[[d2]] : +// FOREACH: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] : +// FOREACH: %[[multiIndex2:.+]]:2 = tensor.delinearize_index %[[iv2]](%[[c11]], %[[d4]] : +// FOREACH: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] : +// FOREACH: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} : +// FOREACH: perform_concurrently +//FOREACH-NEXT: tensor.parallel_insert_slice %[[sliceFlat]] into %[[init]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] : + +// ----- + +// Verifies that a linearized dimension that is not sliced does not generate a loop. Note that this +// only works for static shapes. 
+ +// CHECK: @extract_slice_non_sliced_linearized_dim(%[[arg0:.+]]: tensor<{{.*}}>, +func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32>, %offt: index, %size: index) -> tensor { + %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x2xf32> into tensor + %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 22] [1, 1] : tensor to tensor + // CHECK: scf.for + // CHECK-NOT: scf.for + // CHECK: %[[multiIndex:.+]]:3 = tensor.delinearize_index + // CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1] + return %slice : tensor +} diff --git a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt --- a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt @@ -6,6 +6,7 @@ LINK_LIBS PUBLIC MLIRArithmeticDialect + MLIRLinalgDialect MLIRPass MLIRSCFDialect MLIRTensorDialect diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp --- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp +++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Transforms.h" @@ -28,7 +29,8 @@ TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {} void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); + registry.insert(); } StringRef getArgument() const final { @@ -49,6 +51,19 @@ *this, "test-fold-constant-extract-slice", llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"), llvm::cl::init(false)}; + + Option 
testRewriteExtractSliceWithTiledCollapseShape{ + *this, "test-rewrite-extract-slice-from-collapse-shape", + llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape " + "with loop nest"), + llvm::cl::init(false)}; + + Option useForeach{ + *this, "use-foreach", + llvm::cl::desc( + "Use the scf.foreach_thread operation when generating loop nests for " + "the extract_slice of collapse_shape pattern"), + llvm::cl::init(false)}; }; } // namespace @@ -74,12 +89,23 @@ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); } +static void applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp, + bool useForeach) { + RewritePatternSet patterns(rootOp->getContext()); + auto options = tensor::SliceCollapseShapeOptions().setUseForeach(useForeach); + patterns.add( + rootOp->getContext(), options); + (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns)); +} + void TestTensorTransforms::runOnOperation() { Operation *rootOp = getOperation(); if (testSplitPaddingPatterns) applySplitPaddingPatterns(rootOp); if (testFoldConstantExtractSlice) applyFoldConstantExtractSlicePatterns(rootOp); + if (testRewriteExtractSliceWithTiledCollapseShape) + applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach); } namespace mlir {