diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -643,24 +643,32 @@
   SmallVector<utils::IteratorType> iterators =
       tilingInterfaceOp.getLoopIteratorTypes();
-  SmallVector<unsigned> redDims;
-  linalgOp.getReductionDims(redDims);
-  if (redDims.size() != 1)
-    return b.notifyMatchFailure(
-        op, "only support ops with one reduction dimension.");
   if (!tileSizes.empty() && tileSizes.size() != numThreads.size())
     return b.notifyMatchFailure(op, "if tile sizes are present it must have as "
                                     "many elements as number of threads");
-  int reductionDim = static_cast<int>(redDims.front());
 
-  if (redDims.front() >= numThreads.size())
+  SmallVector<int64_t> tiledReductionDims, reductionInductionVarIndices;
+  int64_t nonZeroTileIdx = 0;
+  for (auto [idx, iteratorType] :
+       llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
+    bool isNonZeroTileSize =
+        idx < numThreads.size() && !isConstantIntValue(numThreads[idx], 0);
+    if (iteratorType == utils::IteratorType::reduction && isNonZeroTileSize) {
+      tiledReductionDims.push_back(idx);
+      reductionInductionVarIndices.push_back(nonZeroTileIdx);
+    }
+    nonZeroTileIdx += isNonZeroTileSize;
+  }
+
+  if (tiledReductionDims.empty()) {
     return b.notifyMatchFailure(
-        op, "reduction dimension must be mapped to threads");
+        op, "at least one reduction dimension must be mapped to threads");
+  }
 
   // 1. Create the inital tensor value.
   FailureOr<Operation *> identityTensor =
       op.generateInitialTensorForPartialReduction(b, loc, numThreads,
-                                                  reductionDim);
+                                                  tiledReductionDims);
   if (failed(identityTensor))
     return b.notifyMatchFailure(op,
                                 "cannot create a tensor of identity value.");
@@ -683,6 +691,7 @@
   scf::ForallOp forallOp = b.create<scf::ForallOp>(
       loc, getAsOpFoldResult(materializedNonZeroNumThreads),
       (*identityTensor)->getResults(), mapping);
+  SmallVector<Value> threadIds = forallOp.getInductionVars();
 
   // 3. Calculate the tile offsets and sizes for the subsequent loop that will
   // be nested under `forallOp`.
@@ -701,6 +710,9 @@
     OpBuilder::InsertionGuard g(b);
     b.setInsertionPoint(forallOp.getTerminator());
 
+    llvm::SmallDenseSet<int64_t> reductionIndexSet(tiledReductionDims.begin(),
+                                                   tiledReductionDims.end());
+
     SmallVector<Value> tiledDpsInitOperands;
     for (OpOperand *initOperand : destinationStyleOp.getDpsInitOperands()) {
       auto *it = llvm::find(dest, initOperand->get());
@@ -709,9 +721,23 @@
       SmallVector<OpFoldResult> strides(numThreads.size(), b.getIndexAttr(1));
      SmallVector<OpFoldResult> outOffsets(numThreads.size(), b.getIndexAttr(0));
-      SmallVector<OpFoldResult> sizes = tiledSizes;
-      sizes[reductionDim] = b.getIndexAttr(1);
-      outOffsets[reductionDim] = forallOp.getInductionVars().front();
+      SmallVector<OpFoldResult> sizes(tiledSizes.begin(),
+                                      tiledSizes.begin() + numThreads.size());
+      for (auto [indVarIdx, redDim] :
+           llvm::zip_equal(reductionInductionVarIndices, tiledReductionDims)) {
+        sizes[redDim] = b.getIndexAttr(1);
+        outOffsets[redDim] = threadIds[indVarIdx];
+      }
+      // Here we are just slicing along tiled reduction dimensions
+      // so that the shape of the output of the cloned op matches
+      // that of the original op. This enables generating the tiled
+      // implementation in the next step, which includes parallel dimension
+      // tiling.
+      for (int i = 0, e = numThreads.size(); i < e; ++i) {
+        if (!reductionIndexSet.contains(i) &&
+            !isConstantIntValue(numThreads[i], 0))
+          sizes[i] = tensor::getMixedSize(b, loc, destBbArgs[destNum], i);
+      }
       // TODO: use SubsetExtractOpInterface once it is available.
       tiledDpsInitOperands.push_back(b.create<tensor::ExtractSliceOp>(
           loc, cast<RankedTensorType>(initOperand->get().getType()),
           destBbArgs[destNum], outOffsets, sizes, strides));
@@ -750,9 +776,8 @@
     if (failed(maybeTiled))
       return b.notifyMatchFailure(op, "failed tileLinalgOpImpl");
 
-    SmallVector<Value> ids = forallOp.getInductionVars();
-    mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()), ids,
-                          materializedNonZeroNumThreads);
+    mapLoopToProcessorIds(cast<scf::ForOp>(maybeTiled->loops.back()),
+                          threadIds, materializedNonZeroNumThreads);
     if (maybeTiled->loops.size() != 1) {
       return clonedOp->emitError("expected a single produced loop");
     }
@@ -777,9 +802,11 @@
       SmallVector<OpFoldResult> resultOffsetsRank, resultSizesRank;
       int64_t offIdx = 0;
       int64_t sizeIdx = 0;
+      int64_t reductionIdx = 0;
       for (int64_t i = 0, e = numThreads.size(); i < e; ++i) {
-        if (i == reductionDim) {
-          resultOffsetsRank.push_back(forallOp.getInductionVars().front());
+        if (tiledReductionDims[reductionIdx] == i) {
+          resultOffsetsRank.push_back(
+              threadIds[reductionInductionVarIndices[reductionIdx++]]);
           resultSizesRank.push_back(b.getIndexAttr(1));
           continue;
         }
@@ -799,7 +826,7 @@
   // 7. Merge the partial reductions.
   b.setInsertionPointAfter(forallOp);
   Operation *mergeOp =
-      op.mergeReductions(b, loc, forallOp->getResults(), reductionDim);
+      op.mergeReductions(b, loc, forallOp->getResults(), tiledReductionDims);
   b.replaceOp(op, mergeOp->getResults());
 
   // 8. Return.
diff --git a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
--- a/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir
@@ -385,3 +385,121 @@
 // CHECK: scf.yield %[[L1]] : tensor<4096x2x64xf32>
 // CHECK: %[[OUT2:.*]] = linalg.generic {indexing_maps = [{{.*}}, {{.*}}], iterator_types = ["parallel", "reduction", "reduction"]} ins(%{{.*}} : tensor<4096x2x64xf32>) outs(%{{.*}} : tensor<4096xf32>)
 // CHECK: return %[[OUT2]] : tensor<4096xf32>
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_multiple_reduction_parallel(%arg0: tensor<32x128xf32>, %arg1: tensor<4x32x128xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> {
+    %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<32x128xf32>, tensor<4x32x128xf32>) outs(%arg2 : tensor<4xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.mulf %in, %in_0 : f32
+      %2 = arith.addf %1, %out : f32
+      linalg.yield %2 : f32
+    } -> tensor<4xf32>
+    return %0 : tensor<4xf32>
+  }
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %loop, %1, %2, %3 = transform.structured.tile_reduction_using_forall %0 by num_threads = [4, 2], tile_sizes = [], mapping = [#gpu.thread<x>, #gpu.thread<y>] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+  }
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-DAG: #[[MAP5:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK: func @reduction_tile_multiple_reduction_parallel(%[[ARG0:.+]]: tensor<32x128xf32>, %[[ARG1:.+]]: tensor<4x32x128xf32>, %[[ARG2:.+]]: tensor<4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<4x2xf32>) -> tensor<4x2xf32>
+// CHECK: %[[L:.*]] = scf.forall (%[[Z:.+]], %[[Y:.+]]) in (4, 2) shared_outs(%[[ARG5:.+]] = %[[F]]) -> (tensor<4x2xf32>) {
+// CHECK: %[[ER:.+]] = tensor.extract_slice %[[ARG5]][0, %[[Y]]] [4, 1] [1, 1] : tensor<4x2xf32> to tensor<4xf32>
+// CHECK: %[[OFF:.+]] = affine.apply #[[MAP]](%[[Y]])
+// CHECK: %[[ESIN0:.+]] = tensor.extract_slice %[[ARG0]][%[[OFF]], 0] [16, 128] [1, 1] : tensor<32x128xf32> to tensor<16x128xf32>
+// CHECK: %[[ESIN1:.+]] = tensor.extract_slice %[[ARG1]][%[[Z]], %[[OFF]], 0] [1, 16, 128] [1, 1, 1] : tensor<4x32x128xf32> to tensor<1x16x128xf32>
+// CHECK: %[[EP:.+]] = tensor.extract_slice %[[ER]][%[[Z]]] [1] [1] : tensor<4xf32> to tensor<1xf32>
+// CHECK: %[[PARTIAL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]], #[[MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
+// CHECK-SAME: ins(%[[ESIN0]], %[[ESIN1]] : tensor<16x128xf32>, tensor<1x16x128xf32>)
+// CHECK-SAME: outs(%[[EP]] : tensor<1xf32>) {
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<1xf32>
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG5]][%[[Z]], %[[Y]]] [1, 1] [1, 1] : tensor<1xf32> into tensor<4x2xf32>
+// CHECK: }
+// CHECK: } {mapping = [#gpu.thread<x>, #gpu.thread<y>]}
+// CHECK: %[[R:.*]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP4]], #[[MAP5]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction"]
+// CHECK-SAME: ins(%[[L]] : tensor<4x2xf32>)
+// CHECK-SAME: outs(%[[ARG2]] : tensor<4xf32>) {
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<4xf32>
+// CHECK: return %[[R]] : tensor<4xf32>
+
+
+// -----
+
+#map = affine_map<(d0, d1, d2) -> (d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map2 = affine_map<(d0, d1, d2) -> (d0)>
+module {
+  func.func @reduction_tile_multiple_reduction_parallel_all_dims(%arg0: tensor<32x128xf32>, %arg1: tensor<4x32x128xf32>, %arg2: tensor<4xf32>) -> tensor<4xf32> {
+    %0 = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "reduction", "reduction"]} ins(%arg0, %arg1 : tensor<32x128xf32>, tensor<4x32x128xf32>) outs(%arg2 : tensor<4xf32>) {
+    ^bb0(%in: f32, %in_0: f32, %out: f32):
+      %1 = arith.mulf %in, %in_0 : f32
+      %2 = arith.addf %1, %out : f32
+      linalg.yield %2 : f32
+    } -> tensor<4xf32>
+    return %0 : tensor<4xf32>
+  }
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    %loop, %1, %2, %3 = transform.structured.tile_reduction_using_forall %0 by num_threads = [0, 2, 4], tile_sizes = [], mapping = [#gpu.thread<x>, #gpu.thread<y>] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
+  }
+}
+
+// CHECK-DAG: #[[MAP:.*]] = affine_map<(d0) -> (d0 * 16)>
+// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0) -> (d0 * 32)>
+// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d1, d2)>
+// CHECK-DAG: #[[MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP4:.*]] = affine_map<(d0, d1, d2) -> (d0)>
+// CHECK: func @reduction_tile_multiple_reduction_parallel_all_dims(%[[ARG0:.+]]: tensor<32x128xf32>, %[[ARG1:.+]]: tensor<4x32x128xf32>, %[[ARG2:.+]]: tensor<4xf32>
+// CHECK: %[[F:.*]] = linalg.fill ins(%{{.*}} : f32) outs(%{{.*}} : tensor<4x2x4xf32>) -> tensor<4x2x4xf32>
+// CHECK: %[[L:.*]] = scf.forall (%[[Z:.+]], %[[Y:.+]]) in (2, 4) shared_outs(%[[ARG5:.+]] = %[[F]]) -> (tensor<4x2x4xf32>) {
+// CHECK: %[[ER:.+]] = tensor.extract_slice %[[ARG5]][0, %[[Z]], %[[Y]]] [4, 1, 1] [1, 1, 1] : tensor<4x2x4xf32> to tensor<4xf32>
+// CHECK: %[[OFFZ:.+]] = affine.apply #[[MAP]](%[[Z]])
+// CHECK: %[[OFFY:.+]] = affine.apply #[[MAP1]](%[[Y]])
+// CHECK: %[[ESIN0:.+]] = tensor.extract_slice %[[ARG0]][%[[OFFZ]], %[[OFFY]]] [16, 32] [1, 1] : tensor<32x128xf32> to tensor<16x32xf32>
+// CHECK: %[[ESIN1:.+]] = tensor.extract_slice %[[ARG1]][0, %[[OFFZ]], %[[OFFY]]] [4, 16, 32] [1, 1, 1] : tensor<4x32x128xf32> to tensor<4x16x32xf32>
+// CHECK: %[[PARTIAL:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
+// CHECK-SAME: ins(%[[ESIN0]], %[[ESIN1]] : tensor<16x32xf32>, tensor<4x16x32xf32>)
+// CHECK-SAME: outs(%[[ER]] : tensor<4xf32>) {
+// CHECK: arith.mulf
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<4xf32>
+// CHECK: scf.forall.in_parallel {
+// CHECK: tensor.parallel_insert_slice %[[PARTIAL]] into %[[ARG5]][0, %[[Z]], %[[Y]]] [4, 1, 1] [1, 1, 1] : tensor<4xf32> into tensor<4x2x4xf32>
+// CHECK: }
+// CHECK: } {mapping = [#gpu.thread<x>, #gpu.thread<y>]}
+// CHECK: %[[R:.*]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP3]], #[[MAP4]]]
+// CHECK-SAME: iterator_types = ["parallel", "reduction", "reduction"]
+// CHECK-SAME: ins(%[[L]] : tensor<4x2x4xf32>)
+// CHECK-SAME: outs(%[[ARG2]] : tensor<4xf32>) {
+// CHECK: arith.addf
+// CHECK: linalg.yield
+// CHECK: } -> tensor<4xf32>
+// CHECK: return %[[R]] : tensor<4xf32>