diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -32,6 +32,84 @@
 /// Populates rewrite patterns that lower `tensor.delinearize_index`.
 void populateLowerDelinearizeIndexPatterns(RewritePatternSet &patterns);
 
+/// Generates IR required to materialize a slice "through" a
+/// `tensor.collapse_shape` by creating a loop nest and populating the slice in
+/// `dest` by stitching together different slices of the source tensor. The loop
+/// nest contains one loop for each sliced output dimension that maps to
+/// multiple source dimensions.
+/// Example:
+// clang-format off
+/// ```
+/// %0 = ... -> tensor<3x7x11x10xf32>
+/// %1 = tensor.collapse_shape %0 [[0, 1, 2], [3]]
+/// %2 = tensor.extract_slice %1 [0, 0] [10, 10] [1, 1]
+/// ```
+/// The "materialized" slice is equivalent to
+/// ```
+/// %2 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = ...
+///   %3:3 = tensor.delinearize_index %iv (3, 7, 11)
+///   %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1]
+///   %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]]
+///   %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1]
+///   scf.yield %6
+/// }
+/// ```
+// clang-format on
+/// This function directly creates the materialized slice from offsets, sizes
+/// and strides. It returns failure when no sliced output dimension is formed
+/// by collapsing multiple source dimensions, since no loop nest is needed in
+/// that case. Note: `loopType` is currently not honored; an `scf.for` loop nest
+/// is always generated.
+FailureOr<Value> materializeSliceFromCollapseShape(
+    OpBuilder &builder, tensor::CollapseShapeOp op, Value dest,
+    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<OpFoldResult> strides, StringRef loopType);
+
+/// A callback type that returns a destination tensor given its sizes and
+/// element type.
+using CreateDestTensorFn =
+    std::function<Value(OpBuilder &, Location, ArrayRef<OpFoldResult>, Type)>;
+
+/// Options for `RewriteExtractSliceWithTiledCollapseShape`.
+struct SliceCollapseShapeOptions {
+  /// Kind of loop to generate. Held by value so that the options (which are
+  /// copied into long-lived patterns) never refer to a caller-owned string.
+  std::string loopType = "for";
+  /// Optional callback that creates the destination tensor; when unset, a
+  /// `linalg.init_tensor` is created.
+  CreateDestTensorFn createDestTensorFn = nullptr;
+
+  SliceCollapseShapeOptions &setLoopType(StringRef scfLoopType) {
+    loopType = scfLoopType.str();
+    return *this;
+  }
+
+  SliceCollapseShapeOptions &setCreateDestTensorFunc(CreateDestTensorFn fn) {
+    createDestTensorFn = std::move(fn);
+    return *this;
+  }
+};
+
+/// Pattern to swap a `tensor.extract_slice` with its producer when the
+/// producer is a `tensor.collapse_shape`. The `tensor.extract_slice` is
+/// replaced by a loop nest that stitches together the desired tile by
+/// iterating over the linearized slice dimensions that cannot be represented
+/// as a rectangular slice of the source tensor (see
+/// `materializeSliceFromCollapseShape`). Rank-reducing slices are not
+/// supported.
+struct RewriteExtractSliceWithTiledCollapseShape
+    : public OpRewritePattern<tensor::ExtractSliceOp> {
+  RewriteExtractSliceWithTiledCollapseShape(MLIRContext *context,
+                                            SliceCollapseShapeOptions options)
+      : OpRewritePattern<tensor::ExtractSliceOp>(context),
+        options(std::move(options)) {}
+
+  LogicalResult matchAndRewrite(tensor::ExtractSliceOp op,
+                                PatternRewriter &rewriter) const override;
+
+  SliceCollapseShapeOptions options;
+};
+
 } // namespace tensor
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -2,6 +2,7 @@
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   DelinearizeIndex.cpp
+  ExtractSliceFromReshape.cpp
   SplitPadding.cpp
   SwapExtractSliceWithProducer.cpp
 
diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshape.cpp
@@ -0,0 +1,343 @@
+//===- ExtractSliceFromReshape.cpp - Slice reshape rewrites-------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements rewrites that replace slices of reshape results with
+// aggregated slices of the reshape source.
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/OpDefinition.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+/// Return an OpFoldResult as a Value, assuming it represents an index type.
+/// Integer constants are materialized as `arith.constant` index ops.
+static Value getAsValue(OpBuilder &b, Location loc, OpFoldResult ofr) {
+  Optional<int64_t> constValue = getConstantIntValue(ofr);
+  if (constValue.hasValue())
+    return b.create<arith::ConstantIndexOp>(loc, *constValue);
+  return ofr.dyn_cast<Value>();
+}
+
+/// The input parameters `offsets`, `sizes`, `strides` specify a rectangular
+/// slice of the collapse_shape output. Find which dimensions have been sliced
+/// and which dimensions are not sliced (offset = 0, size = dim, stride = 1).
+/// Only constants and the static result type are inspected, so no IR is
+/// created and callers can use this to check applicability before rewriting.
+/// This is conservative: a dynamic size is never proven to cover the full
+/// dimension.
+static llvm::SmallBitVector
+getSlicedDimensions(CollapseShapeOp op, ArrayRef<OpFoldResult> offsets,
+                    ArrayRef<OpFoldResult> sizes,
+                    ArrayRef<OpFoldResult> strides) {
+  RankedTensorType resultType = op.getResultType();
+  int64_t rank = resultType.getRank();
+  assert(static_cast<int64_t>(offsets.size()) == rank &&
+         static_cast<int64_t>(sizes.size()) == rank &&
+         static_cast<int64_t>(strides.size()) == rank &&
+         "expected one offset, size and stride per result dimension");
+  llvm::SmallBitVector result(rank);
+  for (int64_t dim = 0; dim < rank; ++dim) {
+    Optional<int64_t> offset = getConstantIntValue(offsets[dim]);
+    Optional<int64_t> size = getConstantIntValue(sizes[dim]);
+    Optional<int64_t> stride = getConstantIntValue(strides[dim]);
+    bool isFullSize = !resultType.isDynamicDim(dim) && size &&
+                      *size == resultType.getDimSize(dim);
+    result[dim] = !isFullSize || !offset || *offset != 0 || !stride ||
+                  *stride != 1;
+  }
+  return result;
+}
+
+/// Determine which dimensions are linearized, i.e. which result dimensions of
+/// the collapse_shape are formed from more than one source dimension.
+static llvm::SmallBitVector getLinearizedDimensions(CollapseShapeOp op) {
+  llvm::SmallBitVector result(op.getResultType().getRank());
+  for (const auto &it : llvm::enumerate(op.getReassociationIndices())) {
+    result[it.index()] = it.value().size() > 1;
+  }
+  return result;
+}
+
+/// Return the result dimensions of `op` that are both sliced (according to
+/// `offsets`, `sizes` and `strides`) and linearized. Each of them needs one
+/// loop when materializing the slice. Creates no IR.
+static llvm::SmallBitVector
+getSlicedAndLinearizedDimensions(CollapseShapeOp op,
+                                 ArrayRef<OpFoldResult> offsets,
+                                 ArrayRef<OpFoldResult> sizes,
+                                 ArrayRef<OpFoldResult> strides) {
+  return getSlicedDimensions(op, offsets, sizes, strides) &
+         getLinearizedDimensions(op);
+}
+
+/// Create an empty loop nest using the given loop parameters. All `scf.yield`
+/// operations are inserted except for the inner-most loop's yield. At least one
+/// loop must be requested; `iterArgInit` is threaded through all loops as the
+/// single iteration argument.
+static scf::LoopNest getEmptyLoopNest(OpBuilder &b, Location loc,
+                                      ArrayRef<Value> nestLowerBounds,
+                                      ArrayRef<Value> nestUpperBounds,
+                                      ArrayRef<Value> nestStrides,
+                                      Value iterArgInit) {
+  assert(!nestLowerBounds.empty() && "expected at least one loop");
+  scf::LoopNest nest;
+  Value iterArg = iterArgInit;
+  Location currLoc = loc;
+  for (const auto &it :
+       llvm::zip(nestLowerBounds, nestUpperBounds, nestStrides)) {
+    auto loop = b.create<scf::ForOp>(
+        currLoc, std::get<0>(it), std::get<1>(it), std::get<2>(it), iterArg,
+        [&](OpBuilder &nB, Location nLoc, Value nIvs, ValueRange nIterArgs) {
+          currLoc = nLoc;
+          iterArg = nIterArgs[0];
+        });
+    b.setInsertionPointToStart(loop.getBody());
+    nest.loops.push_back(loop);
+  }
+  // Every loop but the inner-most one yields the results of its nested loop.
+  for (unsigned i = 0, e = nest.loops.size() - 1; i < e; i++) {
+    b.setInsertionPointToEnd(nest.loops[i].getBody());
+    b.create<scf::YieldOp>(nest.loops[i]->getLoc(),
+                           nest.loops[i + 1].getResults());
+  }
+  return nest;
+}
+
+/// Given a `tensor.collapse_shape` op and information regarding which
+/// dimensions have been collapsed and which are being tiled, as well as the
+/// multi-index elements for each delinearized index, create the tiled form of
+/// the collapse_shape: a `tensor.extract_slice` of `producer` (the
+/// collapse_shape source) collapsed back with the same reassociation.
+static FailureOr<Value> createCollapseShapeTiledResult(
+    CollapseShapeOp collapseOp, Value producer, OpBuilder &b,
+    const llvm::SmallBitVector &slicedOutputDims,
+    const llvm::SmallBitVector &linearizedOutputDims,
+    ArrayRef<SmallVector<Value>> multiIndices, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
+  // Construct offsets to extract from the result of the producer op (which is
+  // the input to the `tensor.collapse_shape`).
+  SmallVector<OpFoldResult> sliceOffsets;
+  SmallVector<OpFoldResult> sliceSizes;
+  SmallVector<OpFoldResult> sliceStrides;
+  int64_t loopIdx = 0;
+  RankedTensorType srcType = collapseOp.getSrcType();
+  Location loc = collapseOp->getLoc();
+  SmallVector<ReassociationIndices> reassociationIndices =
+      collapseOp.getReassociationIndices();
+
+  for (const auto &it : llvm::enumerate(reassociationIndices)) {
+    // Case 1: De-linearized dimensions that have also been sliced. These have
+    // a size of 1 because we are iterating over these dimensions. The offset
+    // is exactly the de-linearized multi index created from the iv's.
+    if (slicedOutputDims[it.index()] && linearizedOutputDims[it.index()]) {
+      sliceSizes.append(it.value().size(), b.getIndexAttr(1));
+      sliceOffsets.append(llvm::to_vector(llvm::map_range(
+          multiIndices[loopIdx++],
+          [&](Value v) -> OpFoldResult { return getAsOpFoldResult(v); })));
+      sliceStrides.append(it.value().size(), b.getIndexAttr(1));
+      continue;
+    }
+
+    // Case 2: One or possibly multiple combined input dimensions, but we have
+    // proven that these are not sliced. In this case we just take the full
+    // extent of each index in the list.
+    if (linearizedOutputDims[it.index()]) {
+      for (int64_t srcIndex : it.value()) {
+        if (!srcType.isDynamicDim(srcIndex))
+          sliceSizes.push_back(b.getIndexAttr(srcType.getDimSize(srcIndex)));
+        else
+          sliceSizes.push_back(b.createOrFold<tensor::DimOp>(
+              loc, producer, b.create<arith::ConstantIndexOp>(loc, srcIndex)));
+        sliceOffsets.push_back(b.getIndexAttr(0));
+        sliceStrides.push_back(b.getIndexAttr(1));
+      }
+      continue;
+    }
+
+    // Case 3: A single index, but it may be sliced.
+    sliceSizes.push_back(sizes[it.index()]);
+    sliceOffsets.push_back(offsets[it.index()]);
+    sliceStrides.push_back(strides[it.index()]);
+  }
+
+  Value tileResult = b.create<tensor::ExtractSliceOp>(
+      loc, producer, sliceOffsets, sliceSizes, sliceStrides);
+
+  // Collapse the dimensions back down.
+  Value collapsedResult = b.create<tensor::CollapseShapeOp>(
+      loc, tileResult, reassociationIndices);
+  return collapsedResult;
+}
+
+FailureOr<Value> mlir::tensor::materializeSliceFromCollapseShape(
+    OpBuilder &b, tensor::CollapseShapeOp collapseOp, Value dest,
+    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+    ArrayRef<OpFoldResult> strides, StringRef loopType) {
+  // NOTE: `loopType` is not honored yet; an `scf.for` loop nest is always
+  // generated.
+  (void)loopType;
+  Location loc = collapseOp->getLoc();
+
+  // Find which dimensions are sliced and/or linearized. This creates no IR, so
+  // bailing out below leaves the IR untouched.
+  llvm::SmallBitVector slicedOutputDims =
+      getSlicedDimensions(collapseOp, offsets, sizes, strides);
+  llvm::SmallBitVector linearizedOutputDims =
+      getLinearizedDimensions(collapseOp);
+
+  // If there are no sliced and linearized dimensions, then we cannot proceed.
+  llvm::SmallBitVector slicedAndLinearized =
+      slicedOutputDims & linearizedOutputDims;
+  if (!slicedAndLinearized.any())
+    return failure();
+
+  // Restore the caller's insertion point once the loop nest is built.
+  OpBuilder::InsertionGuard guard(b);
+
+  // Create the bounds for the loop nest. Each loop enumerates the positions
+  // [0, size) of one sliced and linearized output dimension, so that its
+  // induction variable is directly the position of the sub-tile in `dest`.
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+  SmallVector<Value> nestLowerBounds;
+  SmallVector<Value> nestUpperBounds;
+  SmallVector<Value> nestStrides;
+  for (int idx = slicedAndLinearized.find_first(); idx != -1;
+       idx = slicedAndLinearized.find_next(idx)) {
+    nestLowerBounds.push_back(zero);
+    nestUpperBounds.push_back(getAsValue(b, loc, sizes[idx]));
+    nestStrides.push_back(one);
+  }
+  scf::LoopNest nest = getEmptyLoopNest(b, loc, nestLowerBounds,
+                                        nestUpperBounds, nestStrides, dest);
+
+  // At the start of each loop body, compute the linear index into the
+  // collapsed dimension (`offset + iv * stride`) and de-linearize it into a
+  // multi-index over the corresponding source dimensions.
+  AffineExpr s0, s1, s2;
+  bindSymbols(b.getContext(), s0, s1, s2);
+  SmallVector<SmallVector<Value>> multiIndices;
+  SmallVector<ReassociationIndices> reassociationIndices =
+      collapseOp.getReassociationIndices();
+  for (unsigned i = 0, loopIdx = 0; i < reassociationIndices.size(); i++) {
+    if (!slicedAndLinearized[i])
+      continue;
+    assert(loopIdx < nest.loops.size());
+    scf::ForOp loop = nest.loops[loopIdx++];
+    b.setInsertionPointToStart(loop.getBody());
+    Value linearIndex = makeComposedAffineApply(
+        b, loc, s0 + s1 * s2,
+        {getAsValue(b, loc, offsets[i]), loop.getInductionVar(),
+         getAsValue(b, loc, strides[i])});
+
+    SmallVector<Value> basis;
+    for (int64_t srcDim : reassociationIndices[i]) {
+      basis.push_back(b.createOrFold<tensor::DimOp>(
+          loc, collapseOp.getSrc(),
+          b.create<arith::ConstantIndexOp>(loc, srcDim)));
+    }
+    auto delinOp = b.create<tensor::DelinearizeIndexOp>(
+        loop->getLoc(), /*linear_index=*/linearIndex, /*basis=*/basis);
+    multiIndices.push_back(llvm::to_vector(llvm::map_range(
+        delinOp.getResults(), [](OpResult r) -> Value { return r; })));
+  }
+
+  // Fill out the body of the inner-most loop: the sub-tile for a single
+  // iteration.
+  scf::ForOp innerLoop = nest.loops.back();
+  b.setInsertionPointToEnd(innerLoop.getBody());
+  Value iterArg = innerLoop.getRegionIterArgs()[0];
+  FailureOr<Value> tiledResult = createCollapseShapeTiledResult(
+      collapseOp, collapseOp.getSrc(), b, slicedOutputDims,
+      linearizedOutputDims, multiIndices, offsets, sizes, strides);
+  if (failed(tiledResult))
+    return failure();
+
+  // Insert the collapse_shape sub-tile into the iteration argument.
+  SmallVector<OpFoldResult> insertOffsets;
+  SmallVector<OpFoldResult> insertSizes;
+  for (unsigned i = 0, loopIdx = 0; i < reassociationIndices.size(); i++) {
+    // Case 1: Linearized dimensions that have been sliced. The insert size is
+    // 1, and the offset is the iv, which ranges over [0, size).
+    if (slicedAndLinearized[i]) {
+      insertOffsets.push_back(nest.loops[loopIdx++].getInductionVar());
+      insertSizes.push_back(b.getIndexAttr(1));
+      continue;
+    }
+    // Case 2: Otherwise, the insert is the full shape of the iteration
+    // argument dimension, because this output dimension is not being iterated
+    // over in the loop nest.
+    insertOffsets.push_back(b.getIndexAttr(0));
+    auto iterArgsType = iterArg.getType().cast<RankedTensorType>();
+    if (iterArgsType.isDynamicDim(i))
+      insertSizes.push_back(b.createOrFold<tensor::DimOp>(
+          loc, iterArg, b.createOrFold<arith::ConstantIndexOp>(loc, i)));
+    else
+      insertSizes.push_back(b.getIndexAttr(iterArgsType.getDimSize(i)));
+  }
+
+  Value result = b.create<tensor::InsertSliceOp>(
+      loc, *tiledResult, iterArg, insertOffsets, insertSizes,
+      /*strides=*/
+      SmallVector<OpFoldResult>(insertOffsets.size(), b.getIndexAttr(1)));
+  b.create<scf::YieldOp>(loc, result);
+  return nest.loops.front().getResult(0);
+}
+
+LogicalResult RewriteExtractSliceWithTiledCollapseShape::matchAndRewrite(
+    tensor::ExtractSliceOp op, PatternRewriter &rewriter) const {
+  auto collapseOp = op.getSource().getDefiningOp<tensor::CollapseShapeOp>();
+  if (!collapseOp)
+    return rewriter.notifyMatchFailure(
+        op, "producer is not a tensor.collapse_shape op");
+
+  // The materialized tile has the rank of the collapse_shape result, so the
+  // destination (shaped like the slice) must have that rank as well.
+  if (op.getType().getRank() != op.getSourceType().getRank())
+    return rewriter.notifyMatchFailure(
+        op, "rank-reducing slices are not supported");
+
+  SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = op.getMixedSizes();
+  SmallVector<OpFoldResult> strides = op.getMixedStrides();
+
+  // Check applicability before creating any IR: a pattern must leave the IR
+  // unchanged when it reports a match failure.
+  if (!getSlicedAndLinearizedDimensions(collapseOp, offsets, sizes, strides)
+           .any())
+    return rewriter.notifyMatchFailure(
+        op, "no sliced dimension of the collapse_shape result is linearized");
+
+  // Create the destination tensor. Since the slice is not rank-reducing, its
+  // shape is given by the slice sizes.
+  Value dest;
+  Type elementType = op.getSourceType().getElementType();
+  if (options.createDestTensorFn)
+    dest = options.createDestTensorFn(rewriter, op->getLoc(), sizes,
+                                      elementType);
+  else
+    dest = rewriter.create<linalg::InitTensorOp>(op->getLoc(), sizes,
+                                                 elementType);
+
+  // Materialize the loop nest and replace the `tensor.extract_slice` op.
+  FailureOr<Value> result = materializeSliceFromCollapseShape(
+      rewriter, collapseOp, dest, offsets, sizes, strides, options.loopType);
+  if (failed(result))
+    return rewriter.notifyMatchFailure(
+        op, "failed to extract slice from tensor.collapse_shape");
+  rewriter.replaceOp(op, *result);
+  return success();
+}
diff --git a/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/extract-slice-from-collapse-shape.mlir
@@ -0,0 +1,126 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-rewrite-extract-slice-from-collapse-shape %s | FileCheck %s
+
+func.func @extract_slice_static(%input: tensor<3x5x7x11xf32>) -> tensor<20x11xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
+  %slice = tensor.extract_slice %collapsed [0, 0] [20, 11] [1, 1] : tensor<105x11xf32> to tensor<20x11xf32>
+  return %slice : tensor<20x11xf32>
+}
+
+// CHECK: func.func @extract_slice_static(%[[arg0:.+]]:
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c20:.+]] = arith.constant 20 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
+// CHECK-DAG: %[[init:.+]] = linalg.init_tensor [20, 11] :
+// CHECK-DAG: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c20]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
+// CHECK: %[[multiIndex:.+]]:3 = tensor.delinearize_index %[[iv]](%[[c3]], %[[c5]], %[[c7]]
+// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 11] [1, 1, 1, 1] :
+// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
+// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 11] [1, 1] :
+// CHECK: scf.yield %[[update]] :
+// CHECK: return %[[tile]]
+
+// -----
+
+
+func.func @extract_slice_static_strided(%input: tensor<3x5x7x11xf32>) -> tensor<10x5xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x5x7x11xf32> into tensor<105x11xf32>
+  %slice = tensor.extract_slice %collapsed [13, 0] [10, 5] [2, 2] : tensor<105x11xf32> to tensor<10x5xf32>
+  return %slice : tensor<10x5xf32>
+}
+
+// CHECK: func.func @extract_slice_static_strided(%[[arg0:.+]]:
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c10:.+]] = arith.constant 10 : index
+// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[c5:.+]] = arith.constant 5 : index
+// CHECK-DAG: %[[c7:.+]] = arith.constant 7 : index
+// CHECK: %[[init:.+]] = linalg.init_tensor [10, 5] :
+// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[c10]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
+// CHECK: %[[linear:.+]] = affine.apply #{{.+}}()[%[[iv]]]
+// CHECK: %[[multiIndex:.+]]:3 = tensor.delinearize_index %[[linear]](%[[c3]], %[[c5]], %[[c7]]
+// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
+// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
+// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
+// CHECK: scf.yield %[[update]] :
+// CHECK: return %[[tile]]
+
+
+// -----
+
+
+func.func @extract_slice_dynamic(%input: tensor<3x?x?x11xf32>, %offt: index, %size: index) -> tensor<?x5xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3]] : tensor<3x?x?x11xf32> into tensor<?x11xf32>
+  %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 5] [2, 2] : tensor<?x11xf32> to tensor<?x5xf32>
+  return %slice : tensor<?x5xf32>
+}
+
+// CHECK: #[[map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1 * 2)>
+// CHECK: func.func @extract_slice_dynamic(%[[arg0:.+]]: tensor<{{.*}}>, %[[lb:.+]]: index, %[[sz:.+]]: index)
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
+// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz]], 5] : tensor<?x5xf32>
+// CHECK: %[[tile:.+]] = scf.for %[[iv:.+]] = %[[c0]] to %[[sz]] step %[[c1]] iter_args(%[[iterArg:.+]] = %[[init]])
+// CHECK: %[[linear:.+]] = affine.apply #[[map1]]()[%[[lb]], %[[iv]]]
+// CHECK: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] : tensor<3x?x?x11xf32>
+// CHECK: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] : tensor<3x?x?x11xf32>
+// CHECK: %[[multiIndex:.+]]:3 = tensor.delinearize_index %[[linear]](%[[c3]], %[[d1]], %[[d2]] : index, index, index)
+// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0] [1, 1, 1, 5] [1, 1, 1, 2] :
+// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3]{{\]}} :
+// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg]][%[[iv]], 0] [1, 5] [1, 1] :
+// CHECK: scf.yield %[[update]] :
+// CHECK: return %[[tile]] :
+
+// -----
+
+
+func.func @extract_slice_dynamic_multidim(%input: tensor<3x?x?x11x?xf32>, %offt0: index, %size0: index, %offt1: index, %size1: index) -> tensor<?x?xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x?xf32> into tensor<?x?xf32>
+  %slice = tensor.extract_slice %collapsed [%offt0, %offt1] [%size0, %size1] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+  return %slice : tensor<?x?xf32>
+}
+
+// CHECK: #[[map1:.+]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK: func.func @extract_slice_dynamic_multidim(%[[arg0:.+]]: tensor<3x?x?x11x?xf32>, %[[lb1:.+]]: index, %[[sz1:.+]]: index, %[[lb2:.+]]: index, %[[sz2:.+]]: index)
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c2:.+]] = arith.constant 2 : index
+// CHECK-DAG: %[[c3:.+]] = arith.constant 3 : index
+// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[c11:.+]] = arith.constant 11 : index
+// CHECK: %[[init:.+]] = linalg.init_tensor [%[[sz1]], %[[sz2]]] : tensor<?x?xf32>
+// CHECK: %[[tile1:.+]] = scf.for %[[iv1:.+]] = %[[c0]] to %[[sz1]] step %[[c1]] iter_args(%[[iterArg1:.+]] = %[[init]])
+// CHECK: %[[linear1:.+]] = affine.apply #[[map1]]()[%[[lb1]], %[[iv1]]]
+// CHECK: %[[d1:.+]] = tensor.dim %[[arg0]], %[[c1]] :
+// CHECK: %[[d2:.+]] = tensor.dim %[[arg0]], %[[c2]] :
+// CHECK: %[[multiIndex1:.+]]:3 = tensor.delinearize_index %[[linear1]](%[[c3]], %[[d1]], %[[d2]] :
+// CHECK: %[[tile2:.+]] = scf.for %[[iv2:.+]] = %[[c0]] to %[[sz2]] step %[[c1]] iter_args(%[[iterArg2:.+]] = %[[iterArg1]])
+// CHECK: %[[linear2:.+]] = affine.apply #[[map1]]()[%[[lb2]], %[[iv2]]]
+// CHECK: %[[d4:.+]] = tensor.dim %[[arg0]], %[[c4]] :
+// CHECK: %[[multiIndex2:.+]]:2 = tensor.delinearize_index %[[linear2]](%[[c11]], %[[d4]] :
+// CHECK: %[[slice:.+]] = tensor.extract_slice %[[arg0]][%[[multiIndex1]]#0, %[[multiIndex1]]#1, %[[multiIndex1]]#2, %[[multiIndex2]]#0, %[[multiIndex2]]#1] [1, 1, 1, 1, 1] [1, 1, 1, 1, 1] :
+// CHECK: %[[sliceFlat:.+]] = tensor.collapse_shape %[[slice]] {{\[}}[0, 1, 2], [3, 4]{{\]}} :
+// CHECK: %[[update:.+]] = tensor.insert_slice %[[sliceFlat]] into %[[iterArg2]][%[[iv1]], %[[iv2]]] [1, 1] [1, 1] :
+// CHECK: scf.yield %[[update]] :
+// CHECK: scf.yield %[[tile2]] :
+// CHECK: return %[[tile1]] :
+
+// -----
+
+// Verifies that a linearized dimension that is not sliced does not generate a loop.
+
+// CHECK: @extract_slice_non_sliced_linearized_dim(%[[arg0:.+]]: tensor<{{.*}}>,
+func.func @extract_slice_non_sliced_linearized_dim(%input: tensor<3x?x?x11x2xf32>, %offt: index, %size: index) -> tensor<?x22xf32> {
+  %collapsed = tensor.collapse_shape %input [[0, 1, 2], [3, 4]] : tensor<3x?x?x11x2xf32> into tensor<?x22xf32>
+  %slice = tensor.extract_slice %collapsed [%offt, 0] [%size, 22] [1, 1] : tensor<?x22xf32> to tensor<?x22xf32>
+  // CHECK: scf.for
+  // CHECK-NOT: scf.for
+  // CHECK: %[[multiIndex:.+]]:3 = tensor.delinearize_index
+  // CHECK: tensor.extract_slice %[[arg0]][%[[multiIndex]]#0, %[[multiIndex]]#1, %[[multiIndex]]#2, 0, 0] [1, 1, 1, 11, 2] [1, 1, 1, 1, 1]
+  return %slice : tensor<?x22xf32>
+}
diff --git a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Tensor/CMakeLists.txt
@@ -6,6 +6,7 @@
 
   LINK_LIBS PUBLIC
   MLIRArithmeticDialect
+  MLIRLinalgDialect
   MLIRPass
   MLIRSCFDialect
   MLIRTensorDialect
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
@@ -28,7 +29,8 @@
   TestTensorTransforms(const TestTensorTransforms &pass) : PassWrapper(pass) {}
 
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithmeticDialect, scf::SCFDialect>();
+    registry.insert<arith::ArithmeticDialect, linalg::LinalgDialect,
+                    scf::SCFDialect>();
   }
 
   StringRef getArgument() const final {
@@ -49,6 +51,12 @@
       *this, "test-fold-constant-extract-slice",
       llvm::cl::desc("Test folding arith.constant and tensor.extract_slice"),
       llvm::cl::init(false)};
+
+  Option<bool> testRewriteExtractSliceWithTiledCollapseShape{
+      *this, "test-rewrite-extract-slice-from-collapse-shape",
+      llvm::cl::desc("Test swapping tensor.extract_slice of a collapse_shape "
+                     "with loop nest"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -74,12 +82,22 @@
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applyRewriteExtractFromCollapseShapePatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  auto options = tensor::SliceCollapseShapeOptions().setLoopType("for");
+  patterns.add<tensor::RewriteExtractSliceWithTiledCollapseShape>(
+      rootOp->getContext(), options);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 void TestTensorTransforms::runOnOperation() {
   Operation *rootOp = getOperation();
   if (testSplitPaddingPatterns)
     applySplitPaddingPatterns(rootOp);
   if (testFoldConstantExtractSlice)
     applyFoldConstantExtractSlicePatterns(rootOp);
+  if (testRewriteExtractSliceWithTiledCollapseShape)
+    applyRewriteExtractFromCollapseShapePatterns(rootOp);
 }
 
 namespace mlir {