diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -751,6 +751,8 @@
     All the partial reduction values are parallel inserted to create a new
     tensor. After the loop a merge operation is created to do a final
     reduction with the partial reductions tensor.
+    If an extra `tile_sizes` parameter is passed, the tiles are cyclically
+    distributed across the threads of the `scf.foreach_thread` loop.
 
     #### Return modes
 
@@ -804,7 +806,8 @@
   }];
 
   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads);
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$num_threads,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes);
   let results = (outs PDL_Operation:$fill_op,
                       PDL_Operation:$split_linalg_op,
                       PDL_Operation:$combining_linalg_op);
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -496,7 +496,8 @@
 FailureOr<ForeachThreadReductionTilingResult>
 tileReductionUsingForeachThread(RewriterBase &b, PartialReductionOpInterface op,
                                 ArrayRef<OpFoldResult> numThreads,
-                                Optional<ArrayAttr> mapping);
+                                ArrayRef<OpFoldResult> tileSizes = {},
+                                Optional<ArrayAttr> mapping = llvm::None);
 
 /// All indices returned by IndexOp should be invariant with respect to
 /// tiling. Therefore, if an operation is tiled, we have to transform the
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -1217,16 +1217,12 @@
     transform::TransformState &state) {
   TrivialPatternRewriter rewriter(getContext());
   rewriter.setInsertionPoint(target);
-  SmallVector<int64_t> numThreads = extractFromI64ArrayAttr(getNumThreads());
-  SmallVector<OpFoldResult> numThreadResults;
-  for (int64_t num : numThreads) {
-    numThreadResults.push_back(rewriter.getIndexAttr(num));
-  }
-
+  SmallVector<OpFoldResult> numThreads = getAsOpFoldResult(getNumThreads());
+  SmallVector<OpFoldResult> tileSizes = getAsOpFoldResult(getTileSizes());
   FailureOr<linalg::ForeachThreadReductionTilingResult> result =
       linalg::tileReductionUsingForeachThread(
           rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
-          numThreadResults, /*mapping=*/std::nullopt);
+          numThreads, tileSizes, /*mapping=*/std::nullopt);
 
   if (failed(result)) {
     results.assign(3, nullptr);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/Linalg/Passes.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -410,152 +411,6 @@
       /*omitTileOffsetBoundsCheck=*/true);
 }
 
-FailureOr<ForeachThreadReductionTilingResult>
-linalg::tileReductionUsingForeachThread(RewriterBase &b,
-                                        PartialReductionOpInterface op,
-                                        ArrayRef<OpFoldResult> numThreads,
-                                        Optional<ArrayAttr> mapping) {
-  Location loc = op.getLoc();
-  OpBuilder::InsertionGuard g(b);
-  // Ops implementing PartialReductionOpInterface are expected to implement
-  // TilingInterface.
-  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
-  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
-  if (op->getNumResults() != 1)
-    return b.notifyMatchFailure(
-        op, "don't support ops with multiple results for now");
-  SmallVector<utils::IteratorType> iterators =
-      tilingInterfaceOp.getLoopIteratorTypes();
-  SmallVector<unsigned> redDims;
-  cast<linalg::LinalgOp>(op.getOperation()).getReductionDims(redDims);
-  if (redDims.size() != 1)
-    return b.notifyMatchFailure(
-        op, "only support ops with one reduction dimension.");
-  int reductionDim = static_cast<int>(redDims.front());
-  // 1. create the inital tensor value.
-  FailureOr<Operation *> identityTensor =
-      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
-                                                  reductionDim);
-  if (failed(identityTensor))
-    return b.notifyMatchFailure(op,
-                                "cannot create a tensor of identity value.");
-
-  // Gather destination tensors.
-  SmallVector<Value> dest;
-  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
-    return b.notifyMatchFailure(op, "failed to get destination tensors");
-
-  Operation *tiledOp = nullptr;
-
-  SmallVector<OpFoldResult> nonZeroNumThreads =
-      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 0);
-      }));
-  SmallVector<Value> materializedNonZeroNumThreads =
-      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
-        return getValueOrCreateConstantIndexOp(b, loc, ofr);
-      }));
-
-  // 2. Create the ForeachThreadOp with an empty region.
-  scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
-      loc, identityTensor.value()->getResults(),
-      ValueRange(materializedNonZeroNumThreads), mapping);
-
-  // 3. calculate the tile offsets and sizes.
-  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
-  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
-  calculateTileOffsetsAndSizes(
-      b, loc, foreachThreadOp, numThreads, iterationDomain,
-      /*omitTileOffsetBoundsCheck =*/false,
-      /*nominalTileSizes=*/std::nullopt, tiledOffsets, tiledSizes);
-
-  // 4. Clone the tileable op and update its destination operands to use the
-  // output bbArgs of the ForeachThreadOp.
-  ArrayRef<BlockArgument> destBbArgs =
-      foreachThreadOp.getOutputBlockArguments();
-  Operation *clonedOp = b.clone(*op.getOperation());
-  auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
-  for (OpOperand *outOperand : destinationStyleOp.getDpsInitOperands()) {
-    auto *it = llvm::find(dest, outOperand->get());
-    assert(it != dest.end() && "dest operand not found in dest");
-    unsigned destNum = std::distance(dest.begin(), it);
-    SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
-    SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
-    SmallVector<OpFoldResult> sizes = tiledSizes;
-    sizes[reductionDim] = b.getIndexAttr(1);
-    outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
-    // TODO: use SubsetExtractOpInterface once it is available.
-    Value patial = b.create<tensor::ExtractSliceOp>(
-        loc, outOperand->get().getType().cast<RankedTensorType>(),
-        destBbArgs[destNum], outOffsets, sizes, strides);
-    outOperand->set(patial);
-  }
-
-  // 5. Tile the cloned op and delete the clone.
-  SmallVector<Operation *> tiledOps =
-      cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
-                                                             tiledSizes);
-  b.eraseOp(clonedOp);
-  assert(tiledOps.size() == 1 && "expected a single produced tiled op");
-  tiledOp = tiledOps.front();
-
-  // 6. Insert the partial reductions back into a new tensor.
-  auto tiledInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
-  assert(tiledInterfaceOp && "Tiled op does not implement TilingInterface");
-  OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
-  for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
-                           tiledInterfaceOp->getResults(), destBbArgs)) {
-    b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
-    SmallVector<OpFoldResult> resultOffsets, resultSizes;
-    if (failed(tilingInterfaceOp.getResultTilePosition(
-            b, std::get<0>(it), tiledOffsets, tiledSizes, resultOffsets,
-            resultSizes)))
-      return op->emitOpError("output offsets couldn't be calculated");
-    SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
-    int64_t offIdx = 0;
-    int64_t sizeIdx = 0;
-    for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
-      if (i == reductionDim) {
-        resultOffsetsRank.push_back(foreachThreadOp.getThreadIndices().front());
-        resultSizesRank.push_back(b.getIndexAttr(1));
-        continue;
-      }
-      resultOffsetsRank.push_back(resultOffsets[offIdx++]);
-      resultSizesRank.push_back(resultSizes[sizeIdx++]);
-    }
-
-    SmallVector<OpFoldResult> strides(resultSizesRank.size(),
-                                      b.getIndexAttr(1));
-    b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
-    b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
-                                            std::get<2>(it), resultOffsetsRank,
-                                            resultSizesRank, strides);
-  }
-  // 7. Merge the partial reductions.
-  b.setInsertionPointAfter(foreachThreadOp);
-  Operation *mergeOp =
-      op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim);
-  b.replaceOp(op, mergeOp->getResults());
-  ForeachThreadReductionTilingResult results;
-  results.initialOp = identityTensor.value();
-  results.loops = foreachThreadOp;
-  results.parallelTiledOp = tiledOp;
-  results.mergeOp = mergeOp;
-  return results;
-}
-
-// Insert a tile `source` into the destination tensor `dest`. The position at
-// which the tile is inserted (as well as size of tile) is taken from a given
-// ExtractSliceOp `sliceOp`.
-static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
-                                   tensor::ExtractSliceOp sliceOp, Value source,
-                                   Value dest) {
-  return b.create<tensor::InsertSliceOp>(
-      loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
-      sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
-      sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
-}
-
 template <typename LoopTy>
 static FailureOr<TiledLinalgOp>
 tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
@@ -707,6 +562,165 @@
       res, loops, outermostLoop ? outermostLoop->getResults() : tensorResults};
 }
 
+FailureOr<ForeachThreadReductionTilingResult>
+linalg::tileReductionUsingForeachThread(RewriterBase &b,
+                                        PartialReductionOpInterface op,
+                                        ArrayRef<OpFoldResult> numThreads,
+                                        ArrayRef<OpFoldResult> tileSizes,
+                                        Optional<ArrayAttr> mapping) {
+  Location loc = op.getLoc();
+  OpBuilder::InsertionGuard g(b);
+  // Ops implementing PartialReductionOpInterface are expected to implement
+  // TilingInterface.
+  auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation());
+  SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b);
+  if (op->getNumResults() != 1)
+    return b.notifyMatchFailure(
+        op, "don't support ops with multiple results for now");
+  SmallVector<utils::IteratorType> iterators =
+      tilingInterfaceOp.getLoopIteratorTypes();
+  SmallVector<unsigned> redDims;
+  cast<linalg::LinalgOp>(op.getOperation()).getReductionDims(redDims);
+  if (redDims.size() != 1)
+    return b.notifyMatchFailure(
+        op, "only support ops with one reduction dimension.");
+  if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
+    return b.notifyMatchFailure(op, "if tile sizes are present they must have "
+                                    "as many elements as the number of threads");
+  int reductionDim = static_cast<int>(redDims.front());
+  // 1. Create the initial tensor value.
+  FailureOr<Operation *> identityTensor =
+      op.generateInitialTensorForPartialReduction(b, loc, numThreads,
+                                                  reductionDim);
+  if (failed(identityTensor))
+    return b.notifyMatchFailure(op,
+                                "cannot create a tensor of identity value.");
+
+  // Gather destination tensors.
+  SmallVector<Value> dest;
+  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
+    return b.notifyMatchFailure(op, "failed to get destination tensors");
+
+  Operation *tiledOp = nullptr;
+
+  SmallVector<OpFoldResult> nonZeroNumThreads =
+      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
+        return !isConstantIntValue(ofr, 0);
+      }));
+  SmallVector<Value> materializedNonZeroNumThreads =
+      getAsValues(b, loc, nonZeroNumThreads);
+
+  // 2. Create the ForeachThreadOp with an empty region.
+  scf::ForeachThreadOp foreachThreadOp = b.create<scf::ForeachThreadOp>(
+      loc, identityTensor.value()->getResults(),
+      ValueRange(materializedNonZeroNumThreads), mapping);
+
+  // 3. Calculate the tile offsets and sizes.
+  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
+  calculateTileOffsetsAndSizes(
+      b, loc, foreachThreadOp, numThreads, iterationDomain,
+      /*omitTileOffsetBoundsCheck =*/false,
+      /*nominalTileSizes=*/std::nullopt, tiledOffsets, tiledSizes);
+
+  // 4. Clone the tileable op and update its destination operands to use the
+  // output bbArgs of the ForeachThreadOp.
+  ArrayRef<BlockArgument> destBbArgs =
+      foreachThreadOp.getOutputBlockArguments();
+  Operation *clonedOp = b.clone(*op.getOperation());
+  b.setInsertionPointToStart(foreachThreadOp.getBody(0));
+  auto destinationStyleOp = cast<DestinationStyleOpInterface>(clonedOp);
+  for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) {
+    auto *it = llvm::find(dest, initOperand->get());
+    assert(it != dest.end() && "dest operand not found in dest");
+    unsigned destNum = std::distance(dest.begin(), it);
+    SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
+    SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes = tiledSizes;
+    sizes[reductionDim] = b.getIndexAttr(1);
+    outOffsets[reductionDim] = foreachThreadOp.getThreadIndices().front();
+    // TODO: use SubsetExtractOpInterface once it is available.
+    Value partial = b.create<tensor::ExtractSliceOp>(
+        loc, initOperand->get().getType().cast<RankedTensorType>(),
+        destBbArgs[destNum], outOffsets, sizes, strides);
+    initOperand->set(partial);
+  }
+  b.setInsertionPoint(clonedOp);
+
+  // 5. Tile the cloned op and delete the clone.
+  if (tileSizes.empty()) {
+    SmallVector<Operation *> tiledOps =
+        cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
+                                                               tiledSizes);
+    assert(tiledOps.size() == 1 && "expected a single produced tiled op");
+    tiledOp = tiledOps.front();
+  } else {
+    LinalgTilingOptions options;
+    auto tiled = tileLinalgOpImpl<scf::ForOp>(b, cast<LinalgOp>(clonedOp),
+                                              tileSizes, options);
+    SmallVector<Value> ids = foreachThreadOp.getThreadIndices();
+    mapLoopToProcessorIds(cast<scf::ForOp>(tiled->loops.back()), ids,
+                          materializedNonZeroNumThreads);
+    assert(tiled->loops.size() == 1 && "expected a single produced loop");
+    tiledOp = tiled->loops.front();
+  }
+  b.eraseOp(clonedOp);
+
+  // 6. Insert the partial reductions back into a new tensor.
+  b.setInsertionPointAfter(tiledOp);
+  OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
+  for (auto [index, result, bbArg] :
+       llvm::zip(llvm::seq<unsigned>(0, dest.size()), tiledOp->getResults(),
+                 destBbArgs)) {
+    b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
+    SmallVector<OpFoldResult> resultOffsets, resultSizes;
+    if (failed(tilingInterfaceOp.getResultTilePosition(
+            b, index, tiledOffsets, tiledSizes, resultOffsets, resultSizes)))
+      return op->emitOpError("output offsets couldn't be calculated");
+    SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
+    int64_t offIdx = 0;
+    int64_t sizeIdx = 0;
+    for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
+      if (i == reductionDim) {
+        resultOffsetsRank.push_back(foreachThreadOp.getThreadIndices().front());
+        resultSizesRank.push_back(b.getIndexAttr(1));
+        continue;
+      }
+      resultOffsetsRank.push_back(resultOffsets[offIdx++]);
+      resultSizesRank.push_back(resultSizes[sizeIdx++]);
+    }
+
+    SmallVector<OpFoldResult> strides(resultSizesRank.size(),
+                                      b.getIndexAttr(1));
+    b.setInsertionPointToEnd(foreachThreadOp.getTerminator().getBody());
+    b.create<tensor::ParallelInsertSliceOp>(
+        loc, result, bbArg, resultOffsetsRank, resultSizesRank, strides);
+  }
+  // 7. Merge the partial reductions.
+  b.setInsertionPointAfter(foreachThreadOp);
+  Operation *mergeOp =
+      op.mergeReductions(b, loc, foreachThreadOp->getResults(), reductionDim);
+  b.replaceOp(op, mergeOp->getResults());
+  ForeachThreadReductionTilingResult results;
+  results.initialOp = identityTensor.value();
+  results.loops = foreachThreadOp;
+  results.parallelTiledOp = tiledOp;
+  results.mergeOp = mergeOp;
+  return results;
+}
+
+// Insert a tile `source` into the destination tensor `dest`. The position at
+// which the tile is inserted (as well as size of tile) is taken from a given
+// ExtractSliceOp `sliceOp`.
+static Value insertSliceIntoTensor(OpBuilder &b, Location loc,
+                                   tensor::ExtractSliceOp sliceOp, Value source,
+                                   Value dest) {
+  return b.create<tensor::InsertSliceOp>(
+      loc, sliceOp.getSource().getType(), source, dest, sliceOp.getOffsets(),
+      sliceOp.getSizes(), sliceOp.getStrides(), sliceOp.getStaticOffsets(),
+      sliceOp.getStaticSizes(), sliceOp.getStaticStrides());
+}
+
 template <typename LoopTy>
 FailureOr<TiledLinalgOp> static tileLinalgOpImpl(
     RewriterBase &b, LinalgOp op, const LinalgTilingOptions &options) {
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -126,9 +126,9 @@
 // CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
 // CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
-// CHECK: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
-// CHECK: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
-// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
+// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
+// CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
+// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
 // CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
 // CHECK: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ET]][0] [%[[D0]]] [1] : tensor<?xf32> to tensor<?xf32>
@@ -180,9 +180,9 @@
 // CHECK: %[[E:.*]] = tensor.empty(%[[D3]], %[[D4]]) : tensor<?x?x5xf32>
 // CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x?x5xf32>) -> tensor<?x?x5xf32>
 // CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x?x5xf32>) {
-// CHECK: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
-// CHECK: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
-// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
+// CHECK-DAG: %[[TS0:.+]] = affine.min #[[MAP0]](%[[IV]])[%[[D1]]]
+// CHECK-DAG: %[[TS1:.+]] = affine.max #[[MAP1]](%[[TS0]])
+// CHECK-DAG: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, 0, %[[IV]]] [%[[D0]], %[[D2]], 1] [1, 1, 1] : tensor<?x?x5xf32> to tensor<?x?xf32>
 // CHECK: %[[TINDEX:.+]] = affine.apply #[[MAP2]](%[[IV]])[%[[D1]]]
 // CHECK: %[[INCHUNKA:.+]] = tensor.extract_slice %[[ARG0]][0, %[[TINDEX]]] [%[[D0]], %[[TS1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
 // CHECK: %[[INCHUNKB:.+]] = tensor.extract_slice %[[ARG1]][%[[TINDEX]], 0] [%[[TS1]], %[[D2]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
@@ -197,3 +197,68 @@
 // CHECK: linalg.yield
 // CHECK: } -> tensor<?x?xf32>
 // CHECK: return %[[R]] : tensor<?x?xf32>
+
+// -----
+
+func.func @reduction_tile_parallel_cyclic_dist(
+  %arg0: tensor<?x?xf32>, %out: tensor<?xf32>) -> tensor<?xf32> {
+  %red = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d0)>],
+   iterator_types = ["parallel", "reduction"]}
+   ins(%arg0 : tensor<?x?xf32>)
+   outs(%out : tensor<?xf32>) {
+    ^bb0(%arg7: f32, %arg9: f32):
+      %1 = arith.mulf %arg7, %arg7 : f32
+      %2 = arith.addf %1, %arg9 : f32
+      linalg.yield %2 : f32
+    } -> tensor<?xf32>
+  return %red : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb0(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  %1, %2, %3 = transform.structured.tile_reduction_using_foreach_thread %0 { num_threads = [0, 5], tile_sizes = [0, 3] }
+}
+
+// CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 3)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1) -> (d0)>
+
+// CHECK: func @reduction_tile_parallel_cyclic_dist(%[[ARG0:.+]]: tensor<?x?xf32>, %[[ARG1:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[I:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
+// CHECK-DAG: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index
+// CHECK-DAG: %[[D0:.*]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x?xf32>
+// CHECK-DAG: %[[D2:.*]] = tensor.dim %[[ARG1]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[E:.*]] = tensor.empty(%[[D2]]) : tensor<?x5xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%[[I]] : f32) outs(%[[E]] : tensor<?x5xf32>) -> tensor<?x5xf32>
+// CHECK: %[[L:.*]] = scf.foreach_thread (%[[IV:.+]]) in (%[[C5]]) shared_outs(%[[ARG3:.+]] = %[[F]]) -> (tensor<?x5xf32>) {
+// CHECK: %[[ET:.+]] = tensor.extract_slice %[[ARG3:.+]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?x5xf32> to tensor<?xf32>
+// CHECK: %[[D1:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?xf32>
+// CHECK: %[[LB:.+]] = affine.apply #[[MAP0]]()[%[[IV]]]
+// CHECK: %[[CARRY:.+]] = scf.for %[[IV1:.+]] = %[[LB]] to %[[D1]] step %[[C15]] iter_args(%[[ACC:.+]] = %[[ET]]) -> (tensor<?xf32>) {
+// CHECK: %[[TS0:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[D1]]]
+// CHECK: %[[D3:.+]] = tensor.dim %[[ACC]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[INCHUNK:.+]] = tensor.extract_slice %[[ARG0]][0, %[[IV1]]] [%[[D0]], %[[TS0]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
+// CHECK: %[[TEMPEXT:.+]] = tensor.extract_slice %[[ACC]][0] [%[[D3]]] [1] : tensor<?xf32> to tensor<?xf32>
+// CHECK: %[[PARTIAL:.+]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[INCHUNK]] : tensor<?x?xf32>) outs(%[[TEMPEXT]] : tensor<?xf32>) {
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<?xf32>
+// CHECK: %[[INS:.+]] = tensor.insert_slice %[[PARTIAL]] into %[[ACC]][0] [%[[D3]]] [1] : tensor<?xf32> into tensor<?xf32>
+// CHECK: scf.yield %[[INS]] : tensor<?xf32>
+// CHECK: }
+// CHECK: scf.foreach_thread.perform_concurrently {
+// CHECK: tensor.parallel_insert_slice %[[CARRY]] into %[[ARG3]][0, %[[IV]]] [%[[D0]], 1] [1, 1] : tensor<?xf32> into tensor<?x5xf32>
+// CHECK: }
+// CHECK: }
+// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP2]], #[[MAP3]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<?xf32>
+// CHECK: return %[[R]] : tensor<?xf32>
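
Note (not part of the patch): a minimal C++ usage sketch of the new `tileSizes` overload of `linalg::tileReductionUsingForeachThread` follows, mirroring `num_threads = [0, 5], tile_sizes = [0, 3]` from the new test. The helper name `tileReductionCyclically` and the surrounding setup are hypothetical; only the entry point signature and the `ForeachThreadReductionTilingResult` fields come from the patch.

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

using namespace mlir;

// Hypothetical helper: distribute the reduction dimension of a matched op over
// 5 threads of an scf.foreach_thread, with each thread reducing its chunk
// cyclically in tiles of 3.
static LogicalResult
tileReductionCyclically(RewriterBase &rewriter,
                        PartialReductionOpInterface op) {
  rewriter.setInsertionPoint(op.getOperation());
  // num_threads = [0, 5], tile_sizes = [0, 3]: the parallel dimension stays
  // untiled, only the reduction dimension is distributed.
  SmallVector<OpFoldResult> numThreads = {rewriter.getIndexAttr(0),
                                          rewriter.getIndexAttr(5)};
  SmallVector<OpFoldResult> tileSizes = {rewriter.getIndexAttr(0),
                                         rewriter.getIndexAttr(3)};
  FailureOr<linalg::ForeachThreadReductionTilingResult> result =
      linalg::tileReductionUsingForeachThread(
          rewriter, op, numThreads, tileSizes, /*mapping=*/llvm::None);
  if (failed(result))
    return failure();
  // result->initialOp, result->parallelTiledOp and result->mergeOp correspond
  // to the fill_op, split_linalg_op and combining_linalg_op results of the
  // transform op.
  return success();
}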