diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -255,15 +255,11 @@
       ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
     OpBuilder::InsertionGuard guard(b);
-    assert(reductionDims.size() == 1 &&
-           "only support single reduction right now.");
+
     if (linalgOp.hasBufferSemantics())
       return op->emitOpError("expected operation to have tensor semantics");
     // Insert the new parallel dimension based on the index of the reduction
-    // loop. This could be controlled by user for more flexibility.
-    int64_t insertSplitDimension = reductionDims[0];
-    assert(sizes.size() >= static_cast<size_t>(insertSplitDimension) &&
-           "reduction dimension must be tiled");
+    // loops. This could be controlled by the user for more flexibility.
 
     SmallVector<Operation *, 4> combinerOps;
     if (!matchReduction(linalgOp.getRegionOutputArgs(), 0, combinerOps) ||
@@ -276,18 +272,31 @@
       return op->emitOpError(
           "Failed to get an identity value for the reduction operation.");
 
-    // Calculate the new shape, we insert the new dimension based on the index
-    // of the reduction dimension.
-    SmallVector<int64_t> newOutputShape;
     ArrayRef<int64_t> oldShape =
         linalgOp.getShape(linalgOp.getDpsInitOperand(0));
+
+    // Extend tile size vector to the rank of the output tensor.
+    SmallVector<Value> tileSizeVector =
+        getValueOrCreateConstantIndexOp(b, loc, sizes);
+    if (tileSizeVector.size() < oldShape.size()) {
+      auto zero = b.create<arith::ConstantIndexOp>(loc, 0);
+      tileSizeVector.append(oldShape.size() - tileSizeVector.size(), zero);
+    }
+
+    // Calculate the new shape; we insert the new dimensions based on the index
+    // of the reduction dimensions.
+    SmallVector<int64_t> newOutputShape;
     SmallVector<Value> dynamicDims;
-    for (int64_t idx : llvm::seq<int64_t>(0, oldShape.size() + 1)) {
-      if (idx == insertSplitDimension) {
+    int64_t currReductionDims = 0;
+    DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
+    for (int64_t idx :
+         llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
+      if (reductionDimsSet.contains(idx)) {
         dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
+        currReductionDims++;
         continue;
       }
-      int64_t oldIdx = idx < insertSplitDimension ? idx : idx - 1;
+      int64_t oldIdx = idx - currReductionDims;
       int64_t dim = oldShape[oldIdx];
       newOutputShape.push_back(dim);
       if (ShapedType::isDynamic(dim))
@@ -310,21 +319,20 @@
                                     ArrayRef<int> reductionDims) const {
     OpBuilder::InsertionGuard guard(b);
     auto linalgOp = cast<LinalgOp>(op);
-    assert(reductionDims.size() == 1 &&
-           "only support single reduction right now.");
-    int64_t insertSplitDimension = reductionDims[0];
 
     AffineMap oldOutputMap =
         linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(0));
-    SmallVector<AffineExpr> outputExpr;
-    for (auto [idx, expr] : llvm::enumerate(oldOutputMap.getResults())) {
-      if (static_cast<int64_t>(idx) == insertSplitDimension) {
-        outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
-      }
-      outputExpr.push_back(expr);
+    SmallVector<AffineExpr> outputExpr(oldOutputMap.getNumResults() +
+                                       reductionDims.size());
+
+    for (int idx : reductionDims)
+      outputExpr[idx] = b.getAffineDimExpr(idx);
+    int currExpr = 0;
+    for (int idx : llvm::seq<int>(0, outputExpr.size())) {
+      if (outputExpr[idx])
+        continue;
+      outputExpr[idx] = oldOutputMap.getResult(currExpr++);
     }
-    if (insertSplitDimension == oldOutputMap.getNumResults())
-      outputExpr.push_back(b.getAffineDimExpr(reductionDims[0]));
 
     // Step 1: Extract a slice of the input operands.
     SmallVector<Value> valuesToTile = linalgOp.getDpsInputOperands();
@@ -338,11 +346,12 @@
     Value out = b.create<tensor::ExtractSliceOp>(loc, init[0], outOffsets,
                                                  sizes, strides);
 
-    // Step3. create a generic op where the reduction dimension is replaced by a
-    // parallel dimension of the size of reduction.
+    // Step 3. Create a generic op where the reduction dimensions are replaced
+    // by parallel dimensions of the size of the reduction.
     SmallVector<utils::IteratorType> newIteratorTypes =
         linalgOp.getIteratorTypesArray();
-    newIteratorTypes[reductionDims[0]] = utils::IteratorType::parallel;
+    for (int dim : reductionDims)
+      newIteratorTypes[dim] = utils::IteratorType::parallel;
     SmallVector<AffineMap> newMaps = linalgOp.getIndexingMapsArray();
     newMaps.back() = AffineMap::get(newMaps.back().getNumDims(), 0, outputExpr,
                                     linalgOp.getContext());
@@ -359,24 +368,25 @@
                              ValueRange partialReduce,
                              ArrayRef<int> reductionDims) const {
     auto linalgOp = cast<LinalgOp>(op);
-    assert(reductionDims.size() == 1 &&
-           "only support single reduction right now.");
-    int64_t dimToMerge = reductionDims[0];
-    // Then create a new reduction that only reduce the newly added dimension
+    DenseSet<int> reductionDimsSet(reductionDims.begin(), reductionDims.end());
+
+    // Then create a new reduction that only reduces the newly added dimensions
     // from the previous op.
     int64_t intermRank =
         cast<ShapedType>(partialReduce[0].getType()).getRank();
     AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
     SmallVector<utils::IteratorType> reductionIteratorTypes;
     SmallVector<AffineExpr> exprs;
+
     for (int64_t i : llvm::seq<int64_t>(0, intermRank)) {
-      if (dimToMerge == i) {
+      if (reductionDimsSet.contains(i)) {
         reductionIteratorTypes.push_back(utils::IteratorType::reduction);
       } else {
         exprs.push_back(b.getAffineDimExpr(i));
         reductionIteratorTypes.push_back(utils::IteratorType::parallel);
       }
     }
+
     AffineMap outputMap =
         AffineMap::get(intermRank, 0, exprs, op->getContext());
     SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -419,26 +419,18 @@
         op, "don't support ops with multiple results for now");
   SmallVector<utils::IteratorType> iterators =
       tilingInterfaceOp.getLoopIteratorTypes();
-  int64_t numReductionDims = llvm::count(
-      tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction);
-  if (numReductionDims != 1)
-    return b.notifyMatchFailure(
-        op, "only support ops with one reduction dimension.");
-  int reductionDim;
+
+  SmallVector<int> reductionDims;
   for (auto [idx, iteratorType] :
       llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
-    if (iteratorType == utils::IteratorType::reduction) {
-      reductionDim = idx;
-      break;
-    }
+    if (iteratorType == utils::IteratorType::reduction)
+      reductionDims.push_back(idx);
  }
-  if (static_cast<size_t>(reductionDim) >= tileSize.size())
-    return b.notifyMatchFailure(op, "reduction dimension must be tiled");
 
   // 1. create the inital tensor value.
   FailureOr<Operation *> identityTensor =
       op.generateInitialTensorForPartialReduction(b, loc, tileSize,
-                                                  reductionDim);
+                                                  reductionDims);
   if (failed(identityTensor))
     return b.notifyMatchFailure(op,
                                 "cannot create a tensor of identity value.");
@@ -450,7 +442,7 @@
   // 3. Generate the tiled implementation within the inner most loop.
   b.setInsertionPoint(loops.back().getBody()->getTerminator());
   Operation *parallelOp = op.tileToPartialReduction(
-      b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDim);
+      b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDims);
 
   SmallVector<OpFoldResult> resultSizesList;
   for (size_t i = 0; i < offsets.size(); i++)
@@ -472,7 +464,7 @@
 
   // 4. Apply the merge reduction to combine all the partial values.
   b.setInsertionPointAfter(*loops.begin());
-  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim);
+  Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
   b.replaceOp(op, mergeOp->getResults());
 
   SCFReductionTilingResult results;
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -353,3 +353,35 @@
     %for_op, %fill_op, %split_linalg_op, %combining_linalg_op = transform.structured.tile_reduction_using_scf %0 by tile_sizes = [0, 5] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
   }
 }
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_multiple_reduction(%arg0: tensor<86x128xf32>, %arg1: tensor<4096x86x128xf32>, %arg2: tensor<4096xf32>) -> tensor<4096xf32> {
+    %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<86x128xf32>, tensor<4096x86x128xf32>) outs(%arg2 : tensor<4096xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.mulf %in, %in_0 : f32
+      %2 = arith.addf %1, %out : f32
+      linalg.yield %2 : f32
+    } -> tensor<4096xf32>
+    return %0 : tensor<4096xf32>
+  }
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %for_op, %fill_op, %split_linalg_op, %combining_linalg_op = transform.structured.tile_reduction_using_scf %0 by tile_sizes = [0, 2, 64] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+  }
+}
+
+// CHECK: func @reduction_tile_multiple_reduction(%[[ARG0:.+]]: tensor<86x128xf32>, %[[ARG1:.+]]: tensor<4096x86x128xf32>, %[[ARG2:.+]]: tensor<4096xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<4096x2x64xf32>) -> tensor<4096x2x64xf32>
+// CHECK: %[[L0:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG3:.*]] = %[[F]]) -> (tensor<4096x2x64xf32>)
+// CHECK: %[[L1:.*]] = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%[[ARG4:.*]] = %[[ARG3]]) -> (tensor<4096x2x64xf32>)
+// CHECK: %[[OUT:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}, {{.*}}], iterator_types = ["parallel", "parallel", "parallel"]} ins(%{{.*}}, %{{.*}}: tensor<2x64xf32>, tensor<4096x2x64xf32>) outs(%{{.*}}: tensor<4096x2x64xf32>)
+// CHECK: scf.yield %[[OUT]] : tensor<4096x2x64xf32>
+// CHECK: scf.yield %[[L1]] : tensor<4096x2x64xf32>
+// CHECK: %[[OUT2:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}], iterator_types = ["parallel", "reduction", "reduction"]} ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
+// CHECK: return %[[OUT2]] : tensor<4096xf32>
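
Note (reviewer illustration, not part of the patch): the heart of the change is interleaving the tiled reduction dimensions into the shape (and indexing map) of the partial-result tensor. Below is a minimal standalone C++ sketch of that shape computation, using plain std containers and a hypothetical helper name (buildPartialShape) instead of the MLIR types used in the patch; it only mirrors the index arithmetic of generateInitialTensorForPartialReduction.

#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Each tiled reduction loop becomes a new parallel dimension of the tile
// size, inserted at the position of that loop; the remaining positions keep
// the old output shape in order.
std::vector<int64_t>
buildPartialShape(const std::vector<int64_t> &oldShape,
                  const std::vector<int64_t> &reductionDims,
                  const std::vector<int64_t> &tileSizes) {
  std::set<int64_t> reductionSet(reductionDims.begin(), reductionDims.end());
  std::vector<int64_t> newShape;
  int64_t seenReductionDims = 0;
  int64_t rank = static_cast<int64_t>(oldShape.size() + reductionDims.size());
  for (int64_t idx = 0; idx < rank; ++idx) {
    if (reductionSet.count(idx)) {
      // Tiled reduction loop -> new parallel dimension of the tile size.
      newShape.push_back(tileSizes[idx]);
      ++seenReductionDims;
      continue;
    }
    // Any other position maps back onto the old output shape.
    newShape.push_back(oldShape[idx - seenReductionDims]);
  }
  return newShape;
}

int main() {
  // Mirrors the new test case: output tensor<4096xf32>, reduction dims {1, 2},
  // tile sizes [0, 2, 64] -> partial result tensor<4096x2x64xf32>.
  std::vector<int64_t> shape = buildPartialShape({4096}, {1, 2}, {0, 2, 64});
  assert((shape == std::vector<int64_t>{4096, 2, 64}));
  return 0;
}

The insertion positions follow the reduction loop indices, which is what the comment "This could be controlled by the user for more flexibility" refers to: a different insertion policy would only change where the new parallel dimensions land in this interleaving.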