diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -20,6 +20,7 @@ #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/Dialect/X86Vector/Transforms.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/TilingInterface.h" #include "mlir/Transforms/DialectConversion.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallSet.h" @@ -134,9 +135,10 @@ /// -/// Note that there is no simplification other than constant propagation applied -/// to slice extraction and insertion. -std::pair splitOp(RewriterBase &rewriter, LinalgOp op, - unsigned dimension, - OpFoldResult splitPoint); +/// Note that slice extraction and insertion are only simplified by folding +/// (e.g., constant propagation or composition of nested slices). +std::pair splitOp(RewriterBase &rewriter, + TilingInterface op, + unsigned dimension, + OpFoldResult splitPoint); /// Perform standalone tiling of a single LinalgOp by `tileSizes`. /// and permute the loop nest according to `interchangeVector` diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h --- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h @@ -177,12 +177,6 @@ bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer, Value consumedView, LinalgOp producer); -/// Creates either a memref.subview or a tensor.extract_slice with the given -/// offsets/sizes/strides based on the type of `value`. -Value createSlice(OpBuilder &builder, Location loc, Value value, - ArrayRef offsets, ArrayRef sizes, - ArrayRef strides); - /// Computes tile offsets, given a list of loop `ivs` and `tileSizes`. In case a /// tile size is zero (i.e., no tiling), the corresponding offset is also zero. 
SmallVector computeTileOffsets(OpBuilder &b, Location loc, diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td --- a/mlir/include/mlir/Interfaces/TilingInterface.td +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -76,8 +76,11 @@ operation is to be inserted into. The type of the `dest` Values is same as the types returned by `getDestinationOperands` method. - - `offsets` provides the offset of the tile within the - iteration space + - `offsets` provides the offset of the tile in the coordinate system + of the original iteration space, i.e., if an iteration space + dimension has a non-zero offset, it must be included in the offset + provided here (as opposed to a zero-based offset relative to the + iteration space). - `sizes` provides the size of the tile. - `tileDestOperands` specifies whether to also tile `dest` operands or not. Avoiding tiling `dest` operands can be useful for @@ -141,8 +144,11 @@ operation is to be inserted into. The type of the `dest` Values is same as the types returned by `getDestinationOperands` method. - - `offsets` provides the offset of the tile within the - iteration space + - `offsets` provides the offset of the tile in the coordinate system + of the original iteration space, i.e., if an iteration space + dimension has a non-zero offset, it must be included in the offset + provided here (as opposed to a zero-based offset relative to the + iteration space). - `sizes` provides the size of the tile. - `tileDestOperands` specifies whether to also tile `dest` operands or not.
Avoiding tiling `dest` operands can be useful for diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -740,8 +740,11 @@ } rewriter.setInsertionPoint(linalgOp); - std::tie(first.emplace_back(), second.emplace_back()) = - linalg::splitOp(rewriter, linalgOp, getDimension(), std::get<1>(pair)); + // NOTE(review): splitOp may return a null second part (empty remainder) or + // null parts on failure; confirm null entries in `first`/`second` are handled. + std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp( + rewriter, cast(linalgOp.getOperation()), + getDimension(), std::get<1>(pair)); } results.set(getFirst().cast(), first); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Split.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Split.cpp @@ -8,147 +8,141 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/TilingInterface.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" using namespace mlir; using namespace mlir::linalg; -/// Extract the slices of `operands` supplied to the given operation `op` such -/// that they are sufficient to execute the op for the subset of its iteration -/// space defined by `splitIterationSpace`. The subset is a part of the original -/// iteration space split at the given `dimension`. 
If `offset` is provided, it -/// indicates the iterator value at which the dimension has been split and -/// requires the "high" part starting at the given offset of the operands to be -/// generated; otherwise, the "low" part with no offset is generated.
Note that -/// `operands` are not necessarily the actual operands of `op`. -static SmallVector -getOperandSlices(RewriterBase &b, Location loc, LinalgOp op, - ValueRange splitIterationSpace, ValueRange operands, - unsigned dimension, Value offset = nullptr) { - SmallVector slices; - slices.reserve(op.getNumInputsAndOutputs()); - for (OpOperand *opOperand : op.getInputAndOutputOperands()) { - auto type = opOperand->get().getType().dyn_cast(); - AffineMap indexing = op.getTiedIndexingMap(opOperand); - - // If the type is not sliceable, or the slice is requested along the - // dimension that is not used in indexing this type, just use the entire - // operand. - if (!type || dimension >= indexing.getNumDims() || - !indexing.isFunctionOfDim(dimension)) { - slices.push_back(opOperand->get()); - continue; - } - - SmallVector sizes; - sizes.reserve(indexing.getNumResults()); - for (AffineExpr dimIndexing : indexing.getResults()) { - sizes.push_back(makeComposedFoldedAffineApply( - b, loc, dimIndexing, - getAsOpFoldResult(llvm::to_vector(splitIterationSpace)))); - } - SmallVector offsets(type.getRank(), b.getIndexAttr(0)); - SmallVector strides(type.getRank(), b.getIndexAttr(1)); - - if (offset) { - offsets[dimension] = offset; - offsets = applyMapToValues(b, loc, indexing, offsets); - } - - slices.push_back(createSlice(b, loc, - operands[opOperand->getOperandNumber()], - offsets, sizes, strides)); - } - - return slices; -} - /// Creates a part of the given `op` split along the iteration space `dimension` -/// with the given `size` and an optional `offset` (default 0). Makes slices +/// with the given `size` and `offset` along that dimension. Makes slices /// of operands, using the input operands of the original op and the output -/// operands provided as `resultOperands`. Expects `splitIterationSpace` to be -/// a list of values representing the shape of the iteration space of the -/// original op and updates it to be the iteration space of the curent part.
-/// Returns the split-out op as well as the output operand values updated with -/// the partial results produced by this op through `results`. -static LinalgOp -createSplitPart(RewriterBase &b, Location loc, LinalgOp op, - ValueRange resultOperands, - llvm::MutableArrayRef splitIterationSpace, - unsigned dimension, OpFoldResult size, - SmallVectorImpl &results, Value offset = nullptr) { - ImplicitLocOpBuilder implicit(op.getLoc(), b); - splitIterationSpace[dimension] = materializeOpFoldResult(implicit, size); - SmallVector operands = llvm::to_vector( - llvm::map_range(op.getInputOperands(), - [](OpOperand *opOperand) { return opOperand->get(); })); - llvm::append_range(operands, resultOperands); - operands = getOperandSlices(b, loc, op, splitIterationSpace, operands, - dimension, offset); - Operation *part = - op.clone(b, loc, getTensorOutputTypes(op, operands), operands); - results = insertSlicesBack(b, loc, op, operands, part->getResults()); - return cast(part); +/// operands provided as `resultOperands`. Expects `offsets` and `sizes` to +/// define the shape of the iteration space of the original op. Returns the +/// split-out op as well as the output operand values updated with the partial +/// results produced by this op through `results`. +/// Returns a null op if the part cannot be created, in which case `results` +/// may be partially populated and must be ignored. +static TilingInterface +createSplitPart(RewriterBase &b, Location loc, TilingInterface op, + ArrayRef offsets, ArrayRef sizes, + ValueRange resultOperands, unsigned dimension, + OpFoldResult size, OpFoldResult offset, + SmallVectorImpl &results) { + // Iteration space of the current part. + SmallVector sizesCopy = llvm::to_vector(sizes); + SmallVector offsetsCopy = llvm::to_vector(offsets); + sizesCopy[dimension] = size; + offsetsCopy[dimension] = offset; + + // Create the part as if it were a single tile.
+ SmallVector tiled = + op.getTiledImplementation(b, resultOperands, offsetsCopy, sizesCopy, + /*tileDestOperands=*/true); + if (tiled.empty()) + return nullptr; + assert(tiled.size() == 1 && "expected a single result from tiling"); + auto part = cast(tiled.front()); + + // Insert the results back and populate the `results` list. + for (auto i : llvm::seq(0, part->getNumResults())) { + SmallVector resultOffsets, resultSizes; + // NOTE(review): `resultOffsets` must be relative to `resultOperands[i]`. If + // the tiling implementation composes nested slices, they may instead be + // relative to that operand's source (cf. multisize-tiling-full.mlir). + if (failed(op.getResultTilePosition(b, i, offsetsCopy, sizesCopy, + resultOffsets, resultSizes))) + return nullptr; + SmallVector resultStrides(resultOffsets.size(), + b.getIndexAttr(1)); + Value inserted = b.create( + loc, part->getResult(i), resultOperands[i], resultOffsets, resultSizes, + resultStrides); + results.push_back(inserted); + } + + return part; } -std::pair linalg::splitOp(RewriterBase &rewriter, - LinalgOp op, unsigned dimension, - OpFoldResult splitPoint) { +std::pair +linalg::splitOp(RewriterBase &rewriter, TilingInterface op, unsigned dimension, + OpFoldResult splitPoint) { + // Compute the iteration space. + SmallVector iterationSpace = op.getIterationDomain(rewriter); + + // Bail out on dimension overflow. - if (dimension >= op.getNumLoops()) - return std::make_pair(op, LinalgOp()); - - // Compute the iteration space size as values. - SmallVector allShapes = - op.createFlatListOfOperandDims(rewriter, op.getLoc()); - AffineMap shapesToLoops = op.getShapesToLoopsMap(); - SmallVector iterationSpaceShapes = - applyMapToValues(rewriter, op.getLoc(), shapesToLoops, allShapes); - - // Update the iteration space to have `splitPoint` as the size of `dimension` - // and use it to slice operands and results for a new, smaller instance of the - // `op`. Adjust the size if necessary to prevent overflows. Insert the partial - // results back.
- OpFoldResult dimSize = getAsOpFoldResult(iterationSpaceShapes[dimension]); + if (dimension >= iterationSpace.size()) + return std::make_pair(op, TilingInterface()); + + SmallVector offsets = + getAsOpFoldResult(llvm::to_vector(llvm::map_range( + iterationSpace, [](const Range &range) { return range.offset; }))); + SmallVector sizes = + getAsOpFoldResult(llvm::to_vector(llvm::map_range( + iterationSpace, [](const Range &range) { return range.size; }))); + + // Adjust the split point so that it doesn't overflow the size. + // NOTE(review): the split point is clamped to `offset + size` but is used + // below as a size relative to `offset`; these only agree for a zero offset. + AffineExpr d0, d1, d2; + bindDims(rewriter.getContext(), d0, d1, d2); OpFoldResult minSplitPoint = makeComposedFoldedAffineMin( - rewriter, op->getLoc(), - AffineMap::getMultiDimIdentityMap(/*numDims=*/2, rewriter.getContext()), - {splitPoint, dimSize}); - SmallVector splitIterationSpace = - llvm::to_vector(iterationSpaceShapes); - SmallVector originalResults = llvm::to_vector( - llvm::map_range(op.getOutputOperands(), - [](OpOperand *opOperand) { return opOperand->get(); })); - SmallVector firstResults; - LinalgOp first = createSplitPart(rewriter, op.getLoc(), op, originalResults, - splitIterationSpace, dimension, - minSplitPoint, firstResults); - - // Update the iteration space to cover the remaining part of the original - // space, then create another instance of the `op` in that space. The size of - // the remaining part may become zero, but is never negative because of the - // adjustment above. - AffineExpr d0 = rewriter.getAffineDimExpr(0); - AffineExpr d1 = rewriter.getAffineDimExpr(1); + rewriter, op.getLoc(), + AffineMap::inferFromExprList(ArrayRef{d0, d1 + d2}).front(), + {splitPoint, offsets[dimension], sizes[dimension]}); + + // Compute the size of the second part. Return early if the second part would + // have an empty iteration space.
OpFoldResult remainingSize = makeComposedFoldedAffineApply( - rewriter, op.getLoc(), d0 - d1, {dimSize, minSplitPoint}); + rewriter, op.getLoc(), d0 + d1 - d2, + {iterationSpace[dimension].offset, iterationSpace[dimension].size, + minSplitPoint}); + if (auto attr = remainingSize.dyn_cast()) { + if (attr.cast().getValue().isZero()) + return {op, TilingInterface()}; + } + + // Create the first part. + SmallVector firstResults; + TilingInterface firstPart = createSplitPart( + rewriter, op.getLoc(), op, offsets, sizes, + op.getDestinationOperands(rewriter), dimension, minSplitPoint, + getAsOpFoldResult(iterationSpace[dimension].offset), firstResults); + // Propagate failures before the original op is modified below. + if (!firstPart) + return {TilingInterface(), TilingInterface()}; + + // Need to pretend that the original op now takes as operands firstResults, + // otherwise tiling interface implementation will take the wrong value to + // produce data tiles. + // This assumes the destination operands are the trailing operands of `op`. + rewriter.updateRootInPlace(op, [&]() { + unsigned numTotalOperands = op->getNumOperands(); + unsigned numOutputOperands = firstResults.size(); + op->setOperands(numTotalOperands - numOutputOperands, numOutputOperands, + firstResults); + }); + + // Create the second part. + OpFoldResult totalOffset = makeComposedFoldedAffineApply( + rewriter, op.getLoc(), d0 + d1, {offsets[dimension], minSplitPoint}); SmallVector secondResults; - ImplicitLocOpBuilder implicit(op.getLoc(), rewriter); - Value splitPointValue = materializeOpFoldResult(implicit, minSplitPoint); - LinalgOp second = createSplitPart( - rewriter, op.getLoc(), op, firstResults, splitIterationSpace, dimension, - remainingSize, secondResults, splitPointValue); - - // Fixup the linalg.index results in the second part.
- SmallVector ivAdditions; - ivAdditions.resize(splitIterationSpace.size()); - ivAdditions[dimension] = splitPointValue; - linalg::offsetIndices(rewriter, cast(second), ivAdditions); + TilingInterface secondPart = + createSplitPart(rewriter, op.getLoc(), op, offsets, sizes, firstResults, + dimension, remainingSize, totalOffset, secondResults); + // Propagate failures rather than replacing `op` with a possibly incomplete + // list of values. + if (!secondPart) + return {TilingInterface(), TilingInterface()}; // Replace the original op with the results of the two newly created ops. rewriter.replaceOp(op, secondResults); - return std::make_pair(first, second); + return {firstPart, secondPart}; } diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -913,21 +913,6 @@ return sliceOp->getResult(0); } -Value createSlice(OpBuilder &builder, Location loc, Value value, - ArrayRef offsets, ArrayRef sizes, - ArrayRef strides) { - if (value.getType().isa()) { - return builder.create(loc, value, offsets, sizes, - strides); - } - - // This intentionally does not attempt to compose the extractslice operations.
assert(value.getType().isa() && - "expected a ranked tensor type"); - return builder.create(loc, value, offsets, sizes, - strides); -} - SmallVector computeTileOffsets(OpBuilder &b, Location loc, ValueRange ivs, ValueRange tileSizes) { SmallVector offsets; diff --git a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir --- a/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir +++ b/mlir/test/Dialect/Linalg/multisize-tiling-full.mlir @@ -49,18 +49,17 @@ // CHECK: %[[SLICE_1:.+]] = tensor.extract_slice %[[OUT]][0, 0] [4, 34] [1, 1] // CHECK: scf.for %[[I1:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_1:.+]] = %[[SLICE_1]]) - // CHECK: %[[INSLICE_1:.+]] = tensor.extract_slice %[[IN]][%[[I1]], 0] [2, 34] [1, 1] // CHECK: %[[OUTSLICE_1:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 34] [1, 1] - // CHECK: %[[SLICE_2:.+]] = tensor.extract_slice %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1] + // CHECK: %[[SLICE_2:.+]] = tensor.extract_slice %[[ITERARG_1]][%[[I1]], 0] [2, 16] [1, 1] // CHECK: %[[LOOPRES:.+]] = scf.for %[[I2:.+]] = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ITERARG_2:.+]] = %[[SLICE_2]]) - // CHECK: %[[INSLICE_2:.+]] = tensor.extract_slice %[[INSLICE_1]][0, %[[I2]]] [2, 8] [1, 1] + // CHECK: %[[INSLICE_2:.+]] = tensor.extract_slice %[[IN]][%[[I1]], %[[I2]]] [2, 8] [1, 1] // CHECK: %[[OUTSLICE_2:.+]] = tensor.extract_slice %[[ITERARG_2]][0, %[[I2]]] [2, 8] [1, 1] // CHECK: %[[RESSLICE_1:.+]] = linalg.generic {{.*}} ins(%[[INSLICE_2]] : tensor<2x8xf32>) outs(%[[OUTSLICE_2]] : tensor<2x8xf32>) // CHECK: %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]] // CHECK: scf.yield %[[RESPARTIAL]] - // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1] + // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][%[[I1]], 0] [2, 16] [1, 1] // CHECK: %[[OUTSLICE_3:.+]] = tensor.extract_slice
%[[INSERTED]][0, 16] [2, 18] [1, 1] // CHECK: scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]]) // CHECK-COUNT-2: tensor.extract_slice @@ -74,7 +73,7 @@ // CHECK: tensor.insert_slice // CHECK: tensor.extract_slice // CHECK: scf.for - // CHECK-COUNT-3: tensor.extract_slice + // CHECK-COUNT-2: tensor.extract_slice // CHECK: scf.for // CHECK-COUNT-2: tensor.extract_slice // CHECK: linalg.generic {{.*}} ins(%{{.*}} : tensor<3x8xf32>) diff --git a/mlir/test/Dialect/Linalg/transform-op-split.mlir b/mlir/test/Dialect/Linalg/transform-op-split.mlir --- a/mlir/test/Dialect/Linalg/transform-op-split.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-split.mlir @@ -1,5 +1,4 @@ // RUN: mlir-opt %s --test-transform-dialect-interpreter --split-input-file -verify-diagnostics | FileCheck %s -// RUN: mlir-opt %s --test-transform-dialect-interpreter --canonicalize --split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CANON transform.with_pdl_patterns { ^bb0(%arg0: !pdl.operation): @@ -13,7 +12,6 @@ func.func private @elem(%arg0: f32, %arg1: index, %arg2: index) -> f32 // CHECK: #[[$ADD_42_MAP:.+]] = affine_map<(d0) -> (d0 + 42)> -// CHECK: #[[$ADD_10_MAP:.+]] = affine_map<(d0) -> (d0 + 10)> // CHECK-LABEL: @one_d_static // CHECK-SAME: %[[IN:.+]]: tensor<100xf32>, %[[OUT:.+]]: tensor<100xf32> @@ -53,37 +51,14 @@ // CHECK-LABEL: @one_d_static_overflow // CHECK-SAME: %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32> -// CANON-LABEL: @one_d_static_overflow -// CANON-SAME: %[[IN:.+]]: tensor<10xf32>, %[[OUT:.+]]: tensor<10xf32> func.func @one_d_static_overflow(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> { - // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN]][0] [10] [1] : tensor<10xf32> to tensor<10xf32> - // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT]][0] [10] [1] : tensor<10xf32> to tensor<10xf32> + // Folding is sufficiently powerful to detect the static overflow and avoid + // splitting altogether.
// CHECK: %[[RES_SLICE_LOW:.+]] = linalg.generic - // CHECK: ins(%[[IN_SLICE_LOW]] - // CHECK: outs(%[[OUT_SLICE_LOW]] + // CHECK: ins(%[[IN]] + // CHECK: outs(%[[OUT]] // CHECK: linalg.index 0 // CHECK: func.call @elem - // CHECK: %[[RES_PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [10] [1] - // - // Due to overflow, the first part of the split computes everything and the - // insert/extract slices are folded away by the canonicalizer. - // CANON: %[[RES_PARTIAL:.+]] = linalg.generic - // CANON: ins(%[[IN]] - // CANON: outs(%[[OUT]] - // CANON: linalg.index 0 - // CANON: func.call @elem - // The second part operates on zero-sized slices that are not currently - // folded away. - // - // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN]][10] [0] [1] : tensor<10xf32> to tensor<0xf32> - // CHECK: %[[OUT_SLICE_HIGH:.+]] = tensor.extract_slice %[[RES_PARTIAL]][10] [0] [1] : tensor<10xf32> to tensor<0xf32> - // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic - // CHECK: ins(%[[IN_SLICE_HIGH]] - // CHECK: outs(%[[OUT_SLICE_HIGH]] - // CHECK: %[[IDX:.+]] = linalg.index 0 - // CHECK: affine.apply #[[$ADD_10_MAP]](%[[IDX]]) - // CHECK: func.call @elem - // CHECK: %[[RES:.+]] = tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[RES_PARTIAL]][10] [0] [1] %0 = linalg.generic { indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>], iterator_types = ["parallel"] @@ -118,6 +93,7 @@ func.func @dynamic(%arg0: tensor<100xf32>, %arg1: tensor<100xf32>) -> tensor<100xf32> { // CHECK: %[[SPLIT:.+]] = call @get_size // CHECK: %[[SPLIT_LOW:.+]] = affine.min #[[$MAP_MIN_100]]()[%[[SPLIT]] + // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]] // CHECK: %[[IN_SLICE_LOW:.+]] = tensor.extract_slice %[[IN:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor // CHECK: %[[OUT_SLICE_LOW:.+]] = tensor.extract_slice %[[OUT:.+]][0] [%[[SPLIT_LOW]]] [1] : tensor<100xf32> to tensor // CHECK: %[[RES_SLICE_LOW:.+]] =
linalg.generic @@ -125,7 +101,6 @@ // CHECK: outs(%[[OUT_SLICE_LOW]] // CHECK: %[[PARTIAL:.+]] = tensor.insert_slice %[[RES_SLICE_LOW]] into %[[OUT]][0] [%[[SPLIT_LOW]]] [1] // - // CHECK: %[[SPLIT_HIGH_1:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]] // CHECK: %[[SPLIT_HIGH_2:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]] // CHECK: %[[IN_SLICE_HIGH:.+]] = tensor.extract_slice %[[IN:.+]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_2]]] [1] : tensor<100xf32> to tensor // CHECK: %[[SPLIT_HIGH_3:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]] @@ -133,7 +108,8 @@ // CHECK: %[[RES_SLICE_HIGH:.+]] = linalg.generic // CHECK: ins(%[[IN_SLICE_HIGH]] // CHECK: outs(%[[OUT_SLICE_HIGH]] - // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_3]]] [1] + // CHECK: %[[SPLIT_HIGH_4:.+]] = affine.apply #[[$MAP_S_MINUS_100]]()[%[[SPLIT_LOW]]] + // CHECK: tensor.insert_slice %[[RES_SLICE_HIGH]] into %[[PARTIAL]][%[[SPLIT_LOW]]] [%[[SPLIT_HIGH_4]]] [1] %0 = func.call @get_size() : () -> index %1 = linalg.generic { indexing_maps = [affine_map<(i) -> (i)>, affine_map<(i) -> (i)>], @@ -175,14 +151,16 @@ // // CHECK: %[[IN_2:.+]] = tensor.extract_slice %[[IN]] // CHECK: %[[OUT_2:.+]] = tensor.extract_slice %[[PARTIAL_1]] - // CHECK: %[[IN_21:.+]] = tensor.extract_slice %[[IN_2]] - // CHECK: %[[OUT_21:.+]] = tensor.extract_slice %[[OUT_2]] + // Note that an `extract_slice` of another `extract_slice` result is folded + // to take its slice directly from the source of the first `extract_slice`.
+ // CHECK: %[[IN_21:.+]] = tensor.extract_slice %[[IN]] + // CHECK: %[[OUT_21:.+]] = tensor.extract_slice %[[PARTIAL_1]] // CHECK: %[[RES_21:.+]] = linalg.generic // CHECK-SAME: ins(%[[IN_21]] : tensor<6x16xf32>) // CHECK-SAME: outs(%[[OUT_21]] : tensor<6x16xf32>) // CHECK: %[[PARTIAL_21:.+]] = tensor.insert_slice %[[RES_21]] into %[[OUT_2]] // - // CHECK: %[[IN_22:.+]] = tensor.extract_slice %[[IN_2]] + // CHECK: %[[IN_22:.+]] = tensor.extract_slice %[[IN]] // CHECK: %[[OUT_22:.+]] = tensor.extract_slice %[[PARTIAL_21]] // CHECK: %[[RES_22:.+]] = linalg.generic // CHECK-SAME: ins(%[[IN_22]] : tensor<6x18xf32>)