diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2030,7 +2030,7 @@
   ];
   let extraClassDeclaration = [{
     ::mlir::DiagnosedSilenceableFailure applyToOne(
-        ::mlir::tensor::InsertSliceOp target,
+        ::mlir::Operation *target,
         ::mlir::transform::ApplyToEachResultList &results,
         ::mlir::transform::TransformState &state);
   }];
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -38,6 +38,7 @@
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
+#include <type_traits>
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -3214,18 +3215,44 @@
 //===----------------------------------------------------------------------===//
 // InsertSliceToCopyOp
 //===----------------------------------------------------------------------===//
+
+/// Shared rewrite behind InsertSliceToCopyOp for tensor.insert_slice and
+/// tensor.parallel_insert_slice. When the inserted value is already produced
+/// by a linalg.copy, that copy is reported in `results` and the IR is left
+/// untouched. Otherwise the destination slice is extracted, the source is
+/// copied into it with linalg.copy, and `target` is replaced by an op of the
+/// same kind that inserts the copied value.
+template <typename OpTy>
+static DiagnosedSilenceableFailure
+rewriteInsertSliceToCopy(RewriterBase &rewriter, OpTy target,
+                         transform::ApplyToEachResultList &results) {
+  static_assert((std::is_same_v<OpTy, tensor::InsertSliceOp> ||
+                 std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) &&
+                "wrong op type");
 
-DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
-    tensor::InsertSliceOp target, transform::ApplyToEachResultList &results,
-    transform::TransformState &state) {
-  if (auto copySource = target.getSource().getDefiningOp<linalg::CopyOp>()) {
+  if (auto copySource =
+          target.getSource().template getDefiningOp<linalg::CopyOp>()) {
     results.push_back(copySource);
     return DiagnosedSilenceableFailure::success();
   }
 
-  TrackingListener listener(state, *this);
-  IRRewriter rewriter(target->getContext(), &listener);
-  rewriter.setInsertionPoint(target);
+  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
+    // Create the new ops in front of the enclosing scf.forall.in_parallel
+    // rather than in front of `target`: that region is reserved for the
+    // parallel insertion ops themselves. Bail out before touching the IR if
+    // `target` is not directly nested in one.
+    auto inParallel =
+        dyn_cast_or_null<scf::InParallelOp>(target->getParentOp());
+    if (!inParallel)
+      return emitSilenceableFailure(
+          target.getLoc(),
+          "expected tensor.parallel_insert_slice to be nested directly in "
+          "scf.forall.in_parallel");
+    rewriter.setInsertionPoint(inParallel);
+  } else {
+    rewriter.setInsertionPoint(target);
+  }
+
   Value extracted = rewriter.create<tensor::ExtractSliceOp>(
       target.getLoc(), target.getDest(), target.getMixedOffsets(),
       target.getMixedSizes(), target.getMixedStrides());
@@ -3233,7 +3260,9 @@
       .create<linalg::CopyOp>(target.getLoc(), target.getSource(), extracted)
       .getResult(0);
 
-  rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
+  // Reset the insertion point.
+  rewriter.setInsertionPoint(target);
+  rewriter.replaceOpWithNewOp<OpTy>(
       target, copied, target.getDest(), target.getMixedOffsets(),
       target.getMixedSizes(), target.getMixedStrides());
 
@@ -3241,6 +3270,25 @@
   return DiagnosedSilenceableFailure::success();
 }
 
+/// Dispatches to rewriteInsertSliceToCopy for the two supported op kinds and
+/// reports a silenceable failure for any other payload op.
+DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne(
+    Operation *targetOp, transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  TrackingListener listener(state, *this);
+  IRRewriter rewriter(targetOp->getContext(), &listener);
+  if (auto target = dyn_cast<tensor::InsertSliceOp>(targetOp))
+    return rewriteInsertSliceToCopy(rewriter, target, results);
+  if (auto target = dyn_cast<tensor::ParallelInsertSliceOp>(targetOp))
+    return rewriteInsertSliceToCopy(rewriter, target, results);
+
+  DiagnosedSilenceableFailure diag =
+      emitSilenceableError()
+      << "only InsertSliceOp and ParallelInsertSliceOp ops are supported";
+  diag.attachNote(targetOp->getLoc()) << "target op";
+  return diag;
+}
+
 //===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir b/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir
--- a/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-insert-slice-to-copy.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file | FileCheck %s
+// RUN: mlir-opt -test-transform-dialect-interpreter %s --split-input-file --allow-unregistered-dialect | FileCheck %s
 
 // CHECK-LABEL: func @insert_slice_to_copy
 // CHECK-SAME: %[[I:.*]]: tensor<2x3xf32>
@@ -108,3 +108,30 @@
   transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
 }
 
+// -----
+
+// CHECK-LABEL: func @parallel_insert_slice_to_copy
+func.func @parallel_insert_slice_to_copy(%out : tensor<?x?xf32>, %sz0: index, %sz1: index) {
+  // CHECK: scf.forall
+  // CHECK:   tensor.extract_slice
+  // CHECK:   %[[COPY:.*]] = linalg.copy
+  // CHECK:   scf.forall.in_parallel
+  // CHECK:     tensor.parallel_insert_slice %[[COPY]]
+  %0 = scf.forall (%arg0, %arg1) in (27, 8) shared_outs(%arg2 = %out) -> (tensor<?x?xf32>) {
+    %t = "make_me_a_tensor"() : () -> (tensor<?x?xf32>)
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %t into %arg2[0, 0] [%sz0, %sz1] [1, 1]
+        : tensor<?x?xf32> into tensor<?x?xf32>
+    }
+  }
+  return
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !transform.any_op):
+  %0 = transform.structured.match ops{["tensor.parallel_insert_slice"]} in %arg1
+    : (!transform.any_op) -> !transform.any_op
+  %1 = transform.structured.insert_slice_to_copy %0
+    : (!transform.any_op) -> !transform.any_op
+  transform.cast %1 : !transform.any_op to !transform.op<"linalg.copy">
+}