diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.h
@@ -16,6 +16,9 @@
 namespace mlir {
 class DialectRegistry;
 
+namespace scf {
+class ForallOp;
+} // namespace scf
 
 namespace tensor {
 void registerTransformDialectExtension(DialectRegistry &registry);
diff --git a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
--- a/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td
@@ -157,4 +157,72 @@
   }];
 }
 
+def ShareForallOperandsOp : Op<Transform_Dialect, "share_forall_operands", [
+    FunctionalStyleTransformOpTrait,
+    MemoryEffectsOpInterface,
+    TransformEachOpTrait,
+    TransformOpInterface]> {
+  let description = [{
+    Targets a single scf.forall op and shares all uses of the operands at the
+    indices specified in `share_operands`.
+
+    Sharing can be thought of as the inverse of traditional privatization.
+    Privatization consists in determining that a part of memory is only
+    accessed by a single thread and subsequently slicing that part out into
+    thread-private storage with a smaller footprint, better locality and
+    better alignment properties.
+    In the case of scf.forall on tensors, tensor values are immutable and the
+    same tensor value may be passed as a `shared_outs` operand and also be
+    captured for internal uses. Because of this immutability, such captured
+    tensor values are private by construction and bufferize to an alloc +
+    copy of the whole tensor on every thread in order to preserve the
+    original SSA value.
+
+    An analysis similar to privatization is needed to establish that each
+    thread only requires a private slice and that the whole tensor can
+    therefore be shared. This transformation amounts to injecting the result
+    of such an analysis into the program as static information.
+    In the absence of a cross-thread dependence analysis, the transformation
+    approximates it by checking that each captured use is a
+    `tensor.extract_slice` with a matching `tensor.parallel_insert_slice`.
+    This approximation can still be unsafe with respect to parallelism, so
+    use with care!
+
+    Sharing consists in rewriting all uses of the operands passed as
+    `shared_outs` that are also captured within the `scf.forall` region to
+    use the matching `shared_outs` bbarg instead.
+
+    Only the operands whose indices are listed in `share_operands` are
+    shared. An empty `share_operands` specification shares all operands.
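+
+    For example (the value names below are illustrative only and are not tied
+    to this patch's test), sharing output #0 rewrites the captured use of the
+    tensor passed as `shared_outs` to go through the corresponding bbarg:
+
+    ```mlir
+    // Before: %t is the shared_outs operand and is also captured in the
+    // region, so the extract_slice forces an alloc + copy of the whole
+    // tensor on every thread after bufferization.
+    %r = scf.forall (%i) in (%c8) shared_outs(%out = %t) -> (tensor<8x16xf32>) {
+      %slice = tensor.extract_slice %t[%i, 0] [1, 16] [1, 1]
+          : tensor<8x16xf32> to tensor<1x16xf32>
+      // ... compute %val from %slice ...
+      scf.forall.in_parallel {
+        tensor.parallel_insert_slice %val into %out[%i, 0] [1, 16] [1, 1]
+            : tensor<1x16xf32> into tensor<8x16xf32>
+      }
+    }
+
+    // After sharing operand #0, the captured use reads from the bbarg %out:
+    %slice = tensor.extract_slice %out[%i, 0] [1, 16] [1, 1]
+        : tensor<8x16xf32> to tensor<1x16xf32>
+    ```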
+
+    #### Return modes
+
+    If any of the `share_operands` indices is out of bounds, a definite
+    failure is produced.
+
+    If an operand fails a sharing precondition, it is ignored. In the future,
+    a silenceable notification should be emitted instead.
+
+    This transform consumes the target handle and produces a result handle to
+    the modified `scf.forall` op.
+  }];
+
+  let arguments = (ins
+      TransformHandleTypeInterface:$forall_op,
+      DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$share_operands);
+  let results = (outs TransformHandleTypeInterface:$result);
+
+  let assemblyFormat = [{
+    $forall_op (`share_operands` `=` $share_operands^)? attr-dict
+    `:` functional-type(operands, results)
+  }];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::scf::ForallOp forallOp,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // TENSOR_TRANSFORM_OPS
diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
--- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
+++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp
@@ -169,6 +169,77 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// ShareForallOperandsOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::ShareForallOperandsOp::applyToOne(
+    transform::TransformRewriter &rewriter, scf::ForallOp forallOp,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  SmallVector<int64_t> shareOperands(getShareOperands());
+  // Empty case: consider that all operands need to be shared.
+  if (shareOperands.empty()) {
+    shareOperands =
+        llvm::to_vector(llvm::seq<int64_t>(0, forallOp.getOutputs().size()));
+  }
+  for (int64_t outputIdx : shareOperands) {
+    if (outputIdx < 0 ||
+        outputIdx >= static_cast<int64_t>(forallOp.getOutputs().size()))
+      return mlir::emitDefiniteFailure(forallOp, "operand idx overflow");
+    Value toShare = forallOp.getOutputs()[outputIdx];
+    if (std::distance(toShare.getUses().begin(), toShare.getUses().end()) !=
+        2) {
+      /*return mlir::emitSilenceableFailure(
+          forallOp,
+          "operand to share must have exactly 2 uses, the forall op "
+          "and an extract_slice op.");*/
+      continue;
+    }
+    tensor::ExtractSliceOp extractSliceOp;
+    for (Operation *user : toShare.getUsers()) {
+      extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
+      if (extractSliceOp)
+        break;
+    }
+    if (!extractSliceOp) {
+      /*return mlir::emitSilenceableFailure(
+          forallOp,
+          "shared operand use must be an extract_slice op.");*/
+      continue;
+    }
+    // Get the corresponding bbArg.
+    BlockArgument bbArg = forallOp.getOutputBlockArguments()[outputIdx];
+
+    // Check if the extract_slice has a matching parallel_insert_slice
+    // (i.e., same source/target, offsets, sizes and strides).
+    auto isMatchingParallelInsertSlice = [&](Operation &op) {
+      auto insertSlice = dyn_cast<tensor::ParallelInsertSliceOp>(&op);
+      if (!insertSlice)
+        return false;
+      if (insertSlice.getDest() != bbArg)
+        return false;
+      return llvm::equal(insertSlice.getMixedOffsets(),
+                         extractSliceOp.getMixedOffsets()) &&
+             llvm::equal(insertSlice.getMixedSizes(),
+                         extractSliceOp.getMixedSizes()) &&
+             llvm::equal(insertSlice.getMixedStrides(),
+                         extractSliceOp.getMixedStrides());
+    };
+    if (llvm::none_of(forallOp.getTerminator().getYieldingOps(),
+                      isMatchingParallelInsertSlice)) {
+      continue;
+    }
+
+    // Promote the extract_slice source to the bbArg.
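+    // Redirecting the source to the bbArg encodes that this thread only
+    // needs its own slice of the shared output, so bufferization can reuse
+    // the shared buffer in place instead of materializing an alloc + copy of
+    // the whole tensor per thread.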
+    rewriter.updateRootInPlace(extractSliceOp, [&]() {
+      extractSliceOp.getSourceMutable().assign(bbArg);
+    });
+  }
+
+  results.push_back(forallOp);
+  return DiagnosedSilenceableFailure::success();
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tensor/tiling.mlir b/mlir/test/Dialect/Tensor/tiling.mlir
--- a/mlir/test/Dialect/Tensor/tiling.mlir
+++ b/mlir/test/Dialect/Tensor/tiling.mlir
@@ -645,3 +645,44 @@
   %0 = transform.structured.match ops{["tensor.pack"]} in %arg1 : (!transform.any_op) -> !transform.any_op
   %1, %loops:4 = transform.structured.tile_to_scf_for %0 [1, 1, 1, 1] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map2 = affine_map<(d0) -> (d0 * 4)>
+
+// CHECK-LABEL: @promote
+func.func @promote() -> (tensor<16x128xf32>) {
+  %c0 = arith.constant 0 : index
+  %f0 = arith.constant 0.000000e+00 : f32
+  %c16 = arith.constant 16 : index
+  %c32 = arith.constant 32 : index
+
+  %empty = tensor.empty() : tensor<16x128xf32>
+  %filled = linalg.fill ins(%f0 : f32) outs(%empty : tensor<16x128xf32>) -> tensor<16x128xf32>
+
+  // CHECK: forall{{.*}}shared_outs(%[[ARG:.*]] =
+  // CHECK: %[[A:.*]] = tensor.extract_slice %[[ARG]]
+  // CHECK: %[[C:.*]] = linalg.generic{{.*}}ins(%[[A]]{{.*}}outs(%[[A]]
+  %10 = scf.forall (%arg0, %arg1) in (%c16, %c32) shared_outs(%arg2 = %filled) -> (tensor<16x128xf32>) {
+    %11 = affine.apply #map2(%arg1)
+    %extracted_slice = tensor.extract_slice %filled[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+    %extracted_slice_2 = tensor.extract_slice %arg2[%arg0, %11] [1, 4] [1, 1] : tensor<16x128xf32> to tensor<1x4xf32>
+    %13 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]} ins(%extracted_slice : tensor<1x4xf32>) outs(%extracted_slice_2 : tensor<1x4xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %res = arith.addf %in, %in : f32
+      linalg.yield %res : f32
+    } -> tensor<1x4xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %13 into %arg2[%arg0, %11] [1, 4] [1, 1] : tensor<1x4xf32> into tensor<16x128xf32>
+    }
+  }
+  return %10 : tensor<16x128xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["scf.forall"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+  %1 = transform.cast %0 : !transform.any_op to !transform.op<"scf.forall">
+  transform.share_forall_operands %1 share_operands = [0] : (!transform.op<"scf.forall">) -> !transform.op<"scf.forall">
+}