Index: mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td =================================================================== --- mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -626,7 +626,6 @@ }]; } - def TileReductionUsingScfOp : Op { @@ -714,6 +713,89 @@ }]; } +def TileReductionUsingForeachThreadOp : + Op { + let description = [{ + Tile a PartialReductionOpInterface op to a tiled `scf.foreach_thread` doing + partial reduction. + + This transformation tiles the `target` along the reduction dimensions. It + creates a tensor initialized with the identity value. Then it creates a + `scf.foreach_thread` loop with the number of threads given by `num_threads`. + The op is tiled with a size equal to `floordiv(size, num_threads)`. + All the partial reduction values are parallel-inserted to create a new + tensor. After the loop, a merge operation is created to do a final reduction + with the partial reduction tensor. + + #### Return modes + + The 3 returned handles point to: + - the fill op used to initialize the neutral element, + - the parallel tiled op, and + - the result-combining op. + + #### Example: + + ``` + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor) + outs(%out : tensor) { + ^bb0(%arg7: f32, %arg9: f32): + %1 = arith.addf %arg7, %arg9 : f32 + linalg.yield %1 : f32 + } -> tensor + return %red : tensor + ``` + + is transformed into: + + ``` + %0 = tensor.empty(%dim_1) : tensor + %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor + %2 = scf.foreach_thread (%arg2) in (%c5) shared_outs(%arg3 = %1) -> (tensor) { + %4 = affine.min #map(%arg2)[%dim_0] + %5 = affine.max #map1(%4) + %extracted_slice = tensor.extract_slice %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor to tensor + %6 = affine.apply #map2(%arg2)[%dim_0] + %extracted_slice_2 = tensor.extract_slice %arg0[0, %6] [%dim, %5] [1, 1] : tensor to tensor + %extracted_slice_3 = tensor.extract_slice %extracted_slice[0] [%dim] [1] : tensor to tensor + %7 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%extracted_slice_2 : tensor) outs(%extracted_slice_3 : tensor) { + ^bb0(%in: f32, %out: f32): + %9 = arith.addf %in, %out : f32 + linalg.yield %9 : f32 + } -> tensor + scf.foreach_thread.perform_concurrently { + tensor.parallel_insert_slice %7 into %arg3[0, %arg2] [%dim, 1] [1, 1] : tensor into tensor + } + } {thread_dim_mapping = []} + %3 = linalg.generic {indexing_maps = [#map3, #map4], iterator_types = ["parallel", "reduction"]} ins(%2 : tensor) outs(%arg1 : tensor) { + ^bb0(%in: f32, %out: f32): + %4 = arith.addf %in, %out : f32 + linalg.yield %4 : f32 + } -> tensor + ``` + }]; + + let arguments = (ins PDL_Operation:$target, + DefaultValuedAttr:$num_threads); + let results = (outs PDL_Operation:$fill_op, + PDL_Operation:$split_linalg_op, + PDL_Operation:$combining_linalg_op); + + let assemblyFormat = "$target attr-dict"; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::linalg::LinalgOp target, + ::llvm::SmallVectorImpl<::mlir::Operation *> &results, + ::mlir::transform::TransformState &state); + }]; +} + def TileOp : Op, DeclareOpInterfaceMethods]> {
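For reference, the new `transform.structured.tile_reduction_using_foreach_thread` op is driven from a transform dialect script. A minimal usage sketch, mirroring the test added at the end of this patch (the `num_threads` value is illustrative and the handle names are arbitrary):

```mlir
transform.sequence failures(propagate) {
^bb0(%arg1: !pdl.operation):
  // Match the reduction to tile.
  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
  // The three results are handles to the fill op, the parallel tiled op and
  // the result-combining op, in that order.
  %fill, %partial, %combine =
    transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5] }
}
```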
Index: mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h =================================================================== --- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -445,6 +445,47 @@ ArrayRef tileSizes, Optional mapping); +/// Transformation information returned after reduction tiling. +struct ForeachThreadReductionTilingResult { + /// The partial reduction tiled op generated. + Operation *parallelTiledOp; + /// The final reduction operation merging all the partial reductions. + Operation *mergeOp; + /// The op initializing the tensor used for partial reductions. + Operation *initialOp; + /// The `scf.foreach_thread` operation that iterates over the tiles. + scf::ForeachThreadOp loops; +}; + +/// Method to tile a reduction to parallel iterations computing partial +/// reductions. After the loop, all the partial reductions are merged into a +/// final reduction. For example, the following sequence +/// +/// ```mlir +/// %0 = linalg.generic %in ["parallel", "reduction"] +/// : tensor<7x9xf32> -> tensor<7xf32> +/// ``` +/// +/// is tiled into: +/// +/// ```mlir +/// %0 = linalg.fill ... : tensor<7x4xf32> +/// %1 = scf.foreach_thread (%iv) in (%c4) shared_outs(%arg0 = %0) +/// -> (tensor<7x4xf32>) { +/// %2 = tensor.extract_slice %arg0 : tensor<7x4xf32> to tensor<7xf32> +/// %3 = tensor.extract_slice %in : tensor<7x9xf32> -> tensor<7x?xf32> +/// %4 = linalg.generic %2, %3 ["parallel", "reduction"] +/// : tensor<7x?xf32> -> tensor<7xf32> +/// %5 = tensor.insert_slice %4 into %arg0[0, %iv] : tensor<7x4xf32> +/// } +/// %6 = linalg.generic %1 ["parallel", "reduction"] +/// : tensor<7x4xf32> -> tensor<7xf32> +/// ``` +FailureOr +tileReductionUsingForeachThread(RewriterBase &b, PartialReductionOpInterface op, + ArrayRef numThreads, + Optional mapping); + /// All indices returned by IndexOp should be invariant with respect to /// tiling. Therefore, if an operation is tiled, we have to transform the /// indices accordingly, i.e.
offset them by the values of the corresponding Index: mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp =================================================================== --- mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1165,6 +1165,39 @@ return DiagnosedSilenceableFailure(success()); } +//===----------------------------------------------------------------------===// +// TileReductionUsingForeachThreadOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::TileReductionUsingForeachThreadOp::applyToOne( + linalg::LinalgOp target, SmallVectorImpl &results, + transform::TransformState &state) { + SimpleRewriter rewriter(getContext()); + rewriter.setInsertionPoint(target); + SmallVector numThreads = extractFromI64ArrayAttr(getNumThreads()); + SmallVector numThreadResults; + for (int64_t num : numThreads) { + numThreadResults.push_back(rewriter.getIndexAttr(num)); + } + + FailureOr result = + linalg::tileReductionUsingForeachThread( + rewriter, cast(target.getOperation()), + numThreadResults, /*mapping=*/llvm::None); + + if (failed(result)) { + results.assign(3, nullptr); + Diagnostic diag(target->getLoc(), DiagnosticSeverity::Remark); + diag << "could not tile reduction in target."; + return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); + } + results.push_back(result->initialOp); + results.push_back(result->parallelTiledOp); + results.push_back(result->mergeOp); + return DiagnosedSilenceableFailure(success()); +} + //===----------------------------------------------------------------------===// // TileOp //===----------------------------------------------------------------------===// Index: mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp =================================================================== --- mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -211,58 +211,21 @@ vals); } -/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The -/// tiling is specified by the number of tiles/threads `numThreads` and the -/// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is -/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i], -/// numThreads[i])`. If non-empty, the `mapping` is added as an -/// attribute to the resulting `scf.foreach_thread`. A zero tile sizes indicate -/// that the dimension is not tiled, and can be thought of as tiling by the full -/// size of data. -/// It is the user's responsibility to ensure that `numThreads` is a valid -/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the -/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will -/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds. -static FailureOr tileToForeachThreadOpImpl( - RewriterBase &b, TilingInterface op, ArrayRef numThreads, +/// Fill out the `tiledOffsets` and `tiledSizes` to be used to tile to a given +/// number of threads. 
+static void calculateTileOffsetsAndSizes( + RewriterBase &b, Location loc, scf::ForeachThreadOp foreachThreadOp, + ArrayRef numThreads, SmallVector loopRanges, + bool omitTileOffsetBoundsCheck, Optional> nominalTileSizes, - Optional mapping, bool omitTileOffsetBoundsCheck) { - Location loc = op->getLoc(); - OpBuilder::InsertionGuard g(b); - SmallVector loopRanges = op.getIterationDomain(b); - if (loopRanges.empty()) - return op->emitOpError("expected non-empty loop ranges"); - auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); }; - if (llvm::any_of(loopRanges, hasStrideOne)) - return op->emitOpError("only stride-1 supported atm"); - - // Gather destination tensors. - SmallVector dest; - if (failed(tensor::getOrCreateDestinations(b, loc, op, dest))) - return op->emitOpError("failed to get destination tensors"); - + SmallVector &tiledOffsets, + SmallVector &tiledSizes) { + ValueRange threadIds = foreachThreadOp.getThreadIndices(); SmallVector nonZeroNumThreads = llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { return !isConstantIntValue(ofr, 0); })); - SmallVector materializedNonZeroNumThreads = - llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) { - return getValueOrCreateConstantIndexOp(b, loc, ofr); - })); - - Operation *tiledOp = nullptr; - - // Create the ForeachThreadOp. We don't use the lambda body-builder - // version because we require the use of RewriterBase in the body, so we - // manually move the insertion point to the body below. - scf::ForeachThreadOp foreachThreadOp = b.create( - loc, dest, ValueRange(materializedNonZeroNumThreads), mapping); - - // Fill out the ForeachThreadOp body. - b.setInsertionPointToStart(foreachThreadOp.getBody(0)); - ValueRange threadIds = foreachThreadOp.getThreadIndices(); int64_t nLoops = loopRanges.size(); - SmallVector tiledOffsets, tiledSizes; tiledOffsets.reserve(nLoops); tiledSizes.reserve(nLoops); for (unsigned loopIdx = 0, threadIdIdx = 0; loopIdx < nLoops; ++loopIdx) { @@ -316,6 +279,61 @@ tiledSizes.push_back(tileSizePerThread); ++threadIdIdx; } +} + +/// Rewrite a TilingInterface `op` to a tiled `scf.foreach_thread`. The +/// tiling is specified by the number of tiles/threads `numThreads` and the +/// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is +/// not specified, then it is derived from `numThreads` as `ceilDiv(dimSize[i], +/// numThreads[i])`. If non-empty, the `mapping` is added as an +/// attribute to the resulting `scf.foreach_thread`. A zero tile sizes indicate +/// that the dimension is not tiled, and can be thought of as tiling by the full +/// size of data. +/// It is the user's responsibility to ensure that `numThreads` is a valid +/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the +/// Linalg case). If `omitTileOffsetBoundsCheck` is true, then the function will +/// assume that `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds. 
+static FailureOr tileToForeachThreadOpImpl( + RewriterBase &b, TilingInterface op, ArrayRef numThreads, + Optional> nominalTileSizes, + Optional mapping, bool omitTileOffsetBoundsCheck) { + Location loc = op->getLoc(); + OpBuilder::InsertionGuard g(b); + SmallVector loopRanges = op.getIterationDomain(b); + if (loopRanges.empty()) + return op->emitOpError("expected non-empty loop ranges"); + auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); }; + if (llvm::any_of(loopRanges, hasStrideOne)) + return op->emitOpError("only stride-1 supported atm"); + + // Gather destination tensors. + SmallVector dest; + if (failed(tensor::getOrCreateDestinations(b, loc, op, dest))) + return op->emitOpError("failed to get destination tensors"); + + SmallVector nonZeroNumThreads = + llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { + return !isConstantIntValue(ofr, 0); + })); + SmallVector materializedNonZeroNumThreads = + llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) { + return getValueOrCreateConstantIndexOp(b, loc, ofr); + })); + + Operation *tiledOp = nullptr; + + // Create the ForeachThreadOp. We don't use the lambda body-builder + // version because we require the use of RewriterBase in the body, so we + // manually move the insertion point to the body below. + scf::ForeachThreadOp foreachThreadOp = b.create( + loc, dest, ValueRange(materializedNonZeroNumThreads), mapping); + + // Fill out the ForeachThreadOp body. + b.setInsertionPointToStart(foreachThreadOp.getBody(0)); + SmallVector tiledOffsets, tiledSizes; + calculateTileOffsetsAndSizes(b, loc, foreachThreadOp, numThreads, loopRanges, + omitTileOffsetBoundsCheck, nominalTileSizes, + tiledOffsets, tiledSizes); // Clone the tileable op and update its destination operands to use the output // bbArgs of the ForeachThreadOp. @@ -392,6 +410,140 @@ /*omitTileOffsetBoundsCheck=*/true); } +FailureOr +linalg::tileReductionUsingForeachThread(RewriterBase &b, + PartialReductionOpInterface op, + ArrayRef numThreads, + Optional mapping) { + Location loc = op.getLoc(); + OpBuilder::InsertionGuard g(b); + // Ops implementing PartialReductionOpInterface are expected to implement + // TilingInterface. + auto tilingInterfaceOp = cast(op.getOperation()); + SmallVector iterationDomain = tilingInterfaceOp.getIterationDomain(b); + if (op->getNumResults() != 1) + return b.notifyMatchFailure( + op, "don't support ops with multiple results for now"); + SmallVector iterators = + tilingInterfaceOp.getLoopIteratorTypes(); + SmallVector redDims; + cast(op.getOperation()).getReductionDims(redDims); + if (redDims.size() != 1) + return b.notifyMatchFailure( + op, "only support ops with one reduction dimension."); + int reductionDim = static_cast(redDims.front()); + // 1. Create the initial tensor value. + FailureOr identityTensor = + op.generateInitialTensorForPartialReduction(b, loc, numThreads, + reductionDim); + if (failed(identityTensor)) + return b.notifyMatchFailure(op, + "cannot create a tensor of identity value."); + + // Gather destination tensors.
+ SmallVector dest; + if (failed(tensor::getOrCreateDestinations(b, loc, op, dest))) + return b.notifyMatchFailure(op, "failed to get destination tensors"); + + Operation *tiledOp = nullptr; + + SmallVector nonZeroNumThreads = + llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) { + return !isConstantIntValue(ofr, 0); + })); + SmallVector materializedNonZeroNumThreads = + llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) { + return getValueOrCreateConstantIndexOp(b, loc, ofr); + })); + + // 2. Create the ForeachThreadOp with an empty region. + scf::ForeachThreadOp foreachThreadOp = b.create( + loc, identityTensor.value()->getResults(), + ValueRange(materializedNonZeroNumThreads), mapping); + + // 3. Calculate the tile offsets and sizes. + b.setInsertionPointToStart(foreachThreadOp.getBody(0)); + SmallVector tiledOffsets, tiledSizes; + calculateTileOffsetsAndSizes( + b, loc, foreachThreadOp, numThreads, iterationDomain, + /*omitTileOffsetBoundsCheck=*/false, + /*nominalTileSizes=*/llvm::None, tiledOffsets, tiledSizes); + + // 4. Clone the tileable op and update its destination operands to use the + // output bbArgs of the ForeachThreadOp. + ArrayRef destBbArgs = + foreachThreadOp.getOutputBlockArguments(); + Operation *clonedOp = b.clone(*op.getOperation()); + auto destinationStyleOp = cast(clonedOp); + for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) { + auto *it = llvm::find(dest, outOperand->get()); + assert(it != dest.end() && "dest operand not found in dest"); + unsigned destNum = std::distance(dest.begin(), it); + SmallVector strides(numThreads.size(), b.getIndexAttr(1)); + SmallVector outOffsets(numThreads.size(), b.getIndexAttr(0)); + SmallVector sizes = tiledSizes; + sizes[reductionDim] = b.getIndexAttr(1); + outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front(); + // TODO: use SubsetExtractOpInterface once it is available. + Value partial = b.create( + loc, outOperand->get().getType().cast(), + destBbArgs[destNum], outOffsets, sizes, strides); + outOperand->set(partial); + } + + // 5. Tile the cloned op and delete the clone. + SmallVector tiledOps = + cast(clonedOp).getTiledImplementation(b, tiledOffsets, + tiledSizes); + b.eraseOp(clonedOp); + assert(tiledOps.size() == 1 && "expected a single produced tiled op"); + tiledOp = tiledOps.front(); + + // 6. Insert the partial reductions back into a new tensor.
+ auto tiledInterfaceOp = dyn_cast(tiledOp); + assert(tiledInterfaceOp && "Tiled op does not implement TilingInterface"); + OpBuilder::InsertPoint insertPt = b.saveInsertionPoint(); + for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())), + tiledInterfaceOp->getResults(), destBbArgs)) { + b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint()); + SmallVector resultOffsets, resultSizes; + if (failed(tilingInterfaceOp.getResultTilePosition( + b, std::get<0>(it), tiledOffsets, tiledSizes, resultOffsets, + resultSizes))) + return op->emitOpError("output offsets couldn't be calculated"); + SmallVector resultOffsetsRank, resultSizesRank; + int64_t offIdx = 0; + int64_t sizeIdx = 0; + for (int64_t i = 0, e = numThreads.size(); i < e; ++i) { + if (i == reductionDim) { + resultOffsetsRank.push_back(foreachThreadOp.getThreadIndices().front()); + resultSizesRank.push_back(b.getIndexAttr(1)); + continue; + } + resultOffsetsRank.push_back(resultOffsets[offIdx++]); + resultSizesRank.push_back(resultSizes[sizeIdx++]); + } + + SmallVector strides(resultSizesRank.size(), + b.getIndexAttr(1)); + b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody()); + b.create(loc, std::get<1>(it), + std::get<2>(it), resultOffsetsRank, + resultSizesRank, strides); + } + // 7. Merge the partial reductions. + b.setInsertionPointAfter(foreachThreadOp); + Operation *mergeOp = + op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim); + b.replaceOp(op, mergeOp->getResults()); + ForeachThreadReductionTilingResult results; + results.initialOp = identityTensor.value(); + results.loops = foreachThreadOp; + results.parallelTiledOp = tiledOp; + results.mergeOp = mergeOp; + return results; +} + // Insert a tile `source` into the destination tensor `dest`. The position at // which the tile is inserted (as well as size of tile) is taken from a given // ExtractSliceOp `sliceOp`. 
Index: mlir/test/Dialect/Linalg/transform-tile-reduction.mlir =================================================================== --- mlir/test/Dialect/Linalg/transform-tile-reduction.mlir +++ mlir/test/Dialect/Linalg/transform-tile-reduction.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -canonicalize | FileCheck %s +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -canonicalize -cse | FileCheck %s func.func @reduction_tile(%arg0: tensor, %out: tensor) -> tensor { %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, @@ -86,3 +86,114 @@ // CHECK: } // CHECK: linalg.generic // CHECK: return + +// ----- + +func.func @reduction_tile_parallel( + %arg0: tensor, %out: tensor) -> tensor { + %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"]} + ins(%arg0 : tensor) + outs(%out : tensor) { + ^bb0(%arg7: f32, %arg9: f32): + %1 = arith.mulf %arg7, %arg7 : f32 + %2 = arith.addf %1, %arg9 : f32 + linalg.yield %2 : f32 + } -> tensor + return %red : tensor +} + +transform.sequence failures(propagate) { +^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5] } +} + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1) -> (d0)> +// CHECK: func @reduction_tile_parallel(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor +// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor +// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor +// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor) -> tensor +// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor) { +// CHECK: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]] +// CHECK: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]]) +// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor to tensor +// CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]] +// CHECK: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor to tensor +// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0] [%[[D0]]] [1] : tensor to tensor +// CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor) outs(%[[TEMPEXT]] : tensor) { +// CHECK: arith.mulf +// CHECK: arith.addf +// CHECK: linalg.yield +// CHECK: } -> tensor +// CHECK: scf.foreach_thread.perform_concurrently { +// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor into tensor +// CHECK: } +// CHECK: } +// CHECK: %[[R:.*]] = linalg.generic 
{indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor) outs(%[[ARG1]] : tensor) { +// CHECK: arith.addf +// CHECK: linalg.yield +// CHECK: } -> tensor +// CHECK: return %[[R]] : tensor + +// ----- + +func.func @matmul_tile_parallel( + %A: tensor, %B: tensor, %out: tensor) -> tensor { + %matmul = linalg.matmul ins(%A, %B: tensor, tensor) + outs(%out: tensor) -> tensor + return %matmul : tensor +} + +transform.sequence failures(propagate) { +^bb0(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 0, 5] } +} + +// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-(d0 * (s0 ceildiv 5)) + s0, s0 ceildiv 5)> +// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (0, d0)> +// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 5))> +// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> +// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)> +// CHECK: func @matmul_tile_parallel(%[[ARG0:.+]]: tensor, %[[ARG1:.+]]: tensor, %[[ARG2:.+]]: tensor +// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index +// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor +// CHECK-DAG: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor +// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C1]] : tensor +// CHECK-DAG: %[[D3:.*]] = tensor.dim %[[ARG2]], %[[C0]] : tensor +// CHECK-DAG: %[[D4:.*]] = tensor.dim %[[ARG2]], %[[C1]] : tensor +// CHECK: %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor +// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor) -> tensor +// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor) { +// CHECK: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]] +// CHECK: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]]) +// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor to tensor +// CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]] +// CHECK: %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor to tensor +// CHECK: %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor to tensor +// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0, 0] [%[[D0]], %[[D2]]] [1, 1] : tensor to tensor +// CHECK: %[[PARTIAL:.+]] = linalg.matmul ins(%[[INCHUNKA]], %[[INCHUNKB]] : tensor, tensor) outs(%[[TEMPEXT]] : tensor) -> tensor +// CHECK: scf.foreach_thread.perform_concurrently { +// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG3]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor into tensor +// CHECK: } +// CHECK: } +// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP3]], #[[MAP4]]], iterator_types = ["parallel", "parallel", "reduction"]} ins(%[[L]] : tensor) outs(%[[ARG2]] : tensor) { +// CHECK: arith.addf +// CHECK: linalg.yield +// CHECK: } -> tensor +// CHECK: return %[[R]] : tensor
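For completeness, the new `linalg::tileReductionUsingForeachThread` entry point can also be invoked directly from C++, as the transform op's `applyToOne` does above. A minimal sketch, assuming a rewriter anchored at a matched LinalgOp and the usual MLIR namespaces in scope; the helper name and surrounding boilerplate are hypothetical and not part of this patch:

```cpp
// Hypothetical helper: tile the reduction dimension of `op` across the given
// static thread counts, mirroring TileReductionUsingForeachThreadOp::applyToOne.
static LogicalResult tileReductionAcrossThreads(RewriterBase &rewriter,
                                                linalg::LinalgOp op,
                                                ArrayRef<int64_t> numThreads) {
  rewriter.setInsertionPoint(op);
  // The API takes OpFoldResults; wrap the static counts as index attributes.
  SmallVector<OpFoldResult> numThreadsOfr;
  for (int64_t n : numThreads)
    numThreadsOfr.push_back(rewriter.getIndexAttr(n));

  FailureOr<linalg::ForeachThreadReductionTilingResult> result =
      linalg::tileReductionUsingForeachThread(
          rewriter, cast<PartialReductionOpInterface>(op.getOperation()),
          numThreadsOfr, /*mapping=*/llvm::None);
  if (failed(result))
    return failure();

  // result->initialOp, result->parallelTiledOp, result->mergeOp and
  // result->loops expose the generated fill, partial reduction, merge and
  // scf.foreach_thread ops for further transformation.
  return success();
}
```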